未验证 提交 f9e1ef54 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #144 from 123malin/metrics

add metrics
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import abc import abc
import paddle.fluid as fluid
import numpy as np
class Metric(object): class Metric(object):
...@@ -21,27 +23,58 @@ class Metric(object): ...@@ -21,27 +23,58 @@ class Metric(object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, config): def __init__(self, config):
""" """ """R
"""
pass pass
@abc.abstractmethod def clear(self, scope=None):
def clear(self, scope, params): """R
"""
clear current value
Args:
scope: value container
params: extend varilable for clear
""" """
pass if scope is None:
scope = fluid.global_scope()
@abc.abstractmethod place = fluid.CPUPlace()
def calculate(self, scope, params): for key in self._global_metric_state_vars:
varname, dtype = self._global_metric_state_vars[key]
var = scope.find_var(varname)
if not var:
continue
var = var.get_tensor()
data_array = np.zeros(var._get_dims()).astype(dtype)
var.set(data_array, place)
def _get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
"""R
""" """
calculate result var = scope.find_var(metric_name)
Args: if not var:
scope: value container return None
params: extend varilable for clear input = np.array(var.get_tensor())
if fleet is None:
return input
fleet._role_maker._barrier_worker()
old_shape = np.array(input.shape)
input = input.reshape(-1)
output = np.copy(input) * 0
fleet._role_maker._all_reduce(input, output, mode=mode)
output = output.reshape(old_shape)
return output
def calc_global_metrics(self, fleet, scope=None):
    """
    Gather every registered metric state and compute the global metric.

    Args:
        fleet: passed through to ``_get_global_metric_state``; when it is
            None that helper returns the local value without all-reduce.
        scope: value container holding the state variables; defaults to
            ``fluid.global_scope()`` when None.

    Return:
        whatever the subclass ``_calculate`` returns for the gathered states.
    """
    if scope is None:
        scope = fluid.global_scope()

    global_metrics = dict()
    for key in self._global_metric_state_vars:
        # Each entry is a (variable name, dtype) pair; only the name is
        # needed to look the state up in the scope.
        varname, _ = self._global_metric_state_vars[key]
        # The helper is defined as ``_get_global_metric_state`` (leading
        # underscore); calling it without the underscore raised
        # AttributeError.
        global_metrics[key] = self._get_global_metric_state(fleet, scope,
                                                            varname)

    return self._calculate(global_metrics)
def _calculate(self, global_metrics):
pass pass
@abc.abstractmethod @abc.abstractmethod
...@@ -52,7 +85,6 @@ class Metric(object): ...@@ -52,7 +85,6 @@ class Metric(object):
""" """
pass pass
@abc.abstractmethod
def __str__(self): def __str__(self):
""" """
Return: Return:
......
...@@ -11,3 +11,10 @@ ...@@ -11,3 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .recall_k import RecallK
from .pairwise_pn import PosNegRatio
from .precision_recall import PrecisionRecall
from .auc import AUC
__all__ = ['RecallK', 'PosNegRatio', 'AUC', 'PrecisionRecall']
...@@ -18,102 +18,60 @@ import numpy as np ...@@ -18,102 +18,60 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddlerec.core.metric import Metric from paddlerec.core.metric import Metric
from paddle.fluid.layers.tensor import Variable
class AUCMetric(Metric): class AUC(Metric):
""" """
Metric For Fluid Model Metric For Fluid Model
""" """
def __init__(self, config, fleet): def __init__(self,
input,
label,
curve='ROC',
num_thresholds=2**12 - 1,
topk=1,
slide_steps=1):
""" """ """ """
self.config = config if not isinstance(input, Variable):
self.fleet = fleet raise ValueError("input must be Variable, but received %s" %
type(input))
def clear(self, scope, params): if not isinstance(label, Variable):
""" raise ValueError("label must be Variable, but received %s" %
Clear current metric value, usually set to zero type(label))
Args:
scope : paddle runtime var container auc_out, batch_auc_out, [
params(dict) : batch_stat_pos, batch_stat_neg, stat_pos, stat_neg
label : a group name for metric ] = fluid.layers.auc(input,
metric_dict : current metric_items in group label,
Return: curve=curve,
None num_thresholds=num_thresholds,
""" topk=topk,
self._label = params['label'] slide_steps=slide_steps)
self._metric_dict = params['metric_dict']
self._result = {} prob = fluid.layers.slice(input, axes=[1], starts=[1], ends=[2])
place = fluid.CPUPlace() label_cast = fluid.layers.cast(label, dtype="float32")
for metric_name in self._metric_dict: label_cast.stop_gradient = True
metric_config = self._metric_dict[metric_name] sqrerr, abserr, prob, q, pos, total = \
if scope.find_var(metric_config['var'].name) is None: fluid.contrib.layers.ctr_metric_bundle(prob, label_cast)
continue
metric_var = scope.var(metric_config['var'].name).get_tensor() self._global_metric_state_vars = dict()
data_type = 'float32' self._global_metric_state_vars['stat_pos'] = (stat_pos.name, "float32")
if 'data_type' in metric_config: self._global_metric_state_vars['stat_neg'] = (stat_neg.name, "float32")
data_type = metric_config['data_type'] self._global_metric_state_vars['total_ins_num'] = (total.name,
data_array = np.zeros(metric_var._get_dims()).astype(data_type) "float32")
metric_var.set(data_array, place) self._global_metric_state_vars['pos_ins_num'] = (pos.name, "float32")
self._global_metric_state_vars['q'] = (q.name, "float32")
def get_metric(self, scope, metric_name): self._global_metric_state_vars['prob'] = (prob.name, "float32")
""" self._global_metric_state_vars['abserr'] = (abserr.name, "float32")
reduce metric named metric_name from all worker self._global_metric_state_vars['sqrerr'] = (sqrerr.name, "float32")
Return:
metric reduce result self.metrics = dict()
""" self.metrics["AUC"] = auc_out
metric = np.array(scope.find_var(metric_name).get_tensor()) self.metrics["BATCH_AUC"] = batch_auc_out
old_metric_shape = np.array(metric.shape)
metric = metric.reshape(-1) def _calculate_bucket_error(self, global_pos, global_neg):
global_metric = np.copy(metric) * 0
self.fleet._role_maker.all_reduce_worker(metric, global_metric)
global_metric = global_metric.reshape(old_metric_shape)
return global_metric[0]
def get_global_metrics(self, scope, metric_dict):
"""
reduce all metric in metric_dict from all worker
Return:
dict : {matric_name : metric_result}
"""
self.fleet._role_maker._barrier_worker()
result = {}
for metric_name in metric_dict:
metric_item = metric_dict[metric_name]
if scope.find_var(metric_item['var'].name) is None:
result[metric_name] = None
continue
result[metric_name] = self.get_metric(scope,
metric_item['var'].name)
return result
def calculate_auc(self, global_pos, global_neg):
"""R
"""
num_bucket = len(global_pos)
area = 0.0
pos = 0.0
neg = 0.0
new_pos = 0.0
new_neg = 0.0
total_ins_num = 0
for i in range(num_bucket):
index = num_bucket - 1 - i
new_pos = pos + global_pos[index]
total_ins_num += global_pos[index]
new_neg = neg + global_neg[index]
total_ins_num += global_neg[index]
area += (new_neg - neg) * (pos + new_pos) / 2
pos = new_pos
neg = new_neg
auc_value = None
if pos * neg == 0 or total_ins_num == 0:
auc_value = 0.5
else:
auc_value = area / (pos * neg)
return auc_value
def calculate_bucket_error(self, global_pos, global_neg):
"""R """R
""" """
num_bucket = len(global_pos) num_bucket = len(global_pos)
...@@ -161,56 +119,69 @@ class AUCMetric(Metric): ...@@ -161,56 +119,69 @@ class AUCMetric(Metric):
bucket_error = error_sum / error_count if error_count > 0 else 0.0 bucket_error = error_sum / error_count if error_count > 0 else 0.0
return bucket_error return bucket_error
def calculate(self, scope, params): def _calculate_auc(self, global_pos, global_neg):
""" """ """R
self._label = params['label'] """
self._metric_dict = params['metric_dict'] num_bucket = len(global_pos)
self.fleet._role_maker._barrier_worker() area = 0.0
result = self.get_global_metrics(scope, self._metric_dict) pos = 0.0
neg = 0.0
new_pos = 0.0
new_neg = 0.0
total_ins_num = 0
for i in range(num_bucket):
index = num_bucket - 1 - i
new_pos = pos + global_pos[index]
total_ins_num += global_pos[index]
new_neg = neg + global_neg[index]
total_ins_num += global_neg[index]
area += (new_neg - neg) * (pos + new_pos) / 2
pos = new_pos
neg = new_neg
auc_value = None
if pos * neg == 0 or total_ins_num == 0:
auc_value = 0.5
else:
auc_value = area / (pos * neg)
return auc_value
def _calculate(self, global_metrics):
result = dict()
for key in self._global_metric_state_vars:
if key not in global_metrics:
raise ValueError("%s not existed" % key)
result[key] = global_metrics[key][0]
if result['total_ins_num'] == 0: if result['total_ins_num'] == 0:
self._result = result result['auc'] = 0
self._result['auc'] = 0 result['bucket_error'] = 0
self._result['bucket_error'] = 0 result['actual_ctr'] = 0
self._result['actual_ctr'] = 0 result['predict_ctr'] = 0
self._result['predict_ctr'] = 0 result['mae'] = 0
self._result['mae'] = 0 result['rmse'] = 0
self._result['rmse'] = 0 result['copc'] = 0
self._result['copc'] = 0 result['mean_q'] = 0
self._result['mean_q'] = 0 else:
return self._result result['auc'] = self._calculate_auc(result['stat_pos'],
if 'stat_pos' in result and 'stat_neg' in result: result['stat_neg'])
result['auc'] = self.calculate_auc(result['stat_pos'], result['bucket_error'] = self._calculate_bucket_error(
result['stat_neg']) result['stat_pos'], result['stat_neg'])
result['bucket_error'] = self.calculate_auc(result['stat_pos'],
result['stat_neg'])
if 'pos_ins_num' in result:
result['actual_ctr'] = result['pos_ins_num'] / result[ result['actual_ctr'] = result['pos_ins_num'] / result[
'total_ins_num'] 'total_ins_num']
if 'abserr' in result:
result['mae'] = result['abserr'] / result['total_ins_num'] result['mae'] = result['abserr'] / result['total_ins_num']
if 'sqrerr' in result:
result['rmse'] = math.sqrt(result['sqrerr'] / result['rmse'] = math.sqrt(result['sqrerr'] /
result['total_ins_num']) result['total_ins_num'])
if 'prob' in result:
result['predict_ctr'] = result['prob'] / result['total_ins_num'] result['predict_ctr'] = result['prob'] / result['total_ins_num']
if abs(result['predict_ctr']) > 1e-6: if abs(result['predict_ctr']) > 1e-6:
result['copc'] = result['actual_ctr'] / result['predict_ctr'] result['copc'] = result['actual_ctr'] / result['predict_ctr']
if 'q' in result:
result['mean_q'] = result['q'] / result['total_ins_num'] result['mean_q'] = result['q'] / result['total_ins_num']
self._result = result
return result
def get_result(self):
""" """
return self._result
def __str__(self): result_str = "AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f " \
""" """
result = self.get_result()
result_str = "%s AUC=%.6f BUCKET_ERROR=%.6f MAE=%.6f RMSE=%.6f " \
"Actural_CTR=%.6f Predicted_CTR=%.6f COPC=%.6f MEAN Q_VALUE=%.6f Ins number=%s" % \ "Actural_CTR=%.6f Predicted_CTR=%.6f COPC=%.6f MEAN Q_VALUE=%.6f Ins number=%s" % \
(self._label, result['auc'], result['bucket_error'], result['mae'], result['rmse'], (result['auc'], result['bucket_error'], result['mae'], result['rmse'],
result['actual_ctr'], result['actual_ctr'],
result['predict_ctr'], result['copc'], result['mean_q'], result['total_ins_num']) result['predict_ctr'], result['copc'], result['mean_q'], result['total_ins_num'])
return result_str return result_str
def get_result(self):
return self.metrics
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.metric import Metric
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable
class PosNegRatio(Metric):
    """
    Positive/negative pair ordering metric for Fluid models.

    For every batch, an element is counted as "right" when
    pos_score > neg_score and as "wrong" when pos_score <= neg_score.
    Both counts are accumulated into persistable global variables and
    PN is computed as (right + 1) / (wrong + 1).
    """

    def __init__(self, pos_score, neg_score):
        """
        Build the ops that accumulate the right/wrong counters.

        Args:
            pos_score(Variable): scores compared element-wise against
                neg_score.
            neg_score(Variable): scores compared element-wise against
                pos_score.

        Raises:
            ValueError: if pos_score or neg_score is not a Variable.
        """
        helper = LayerHelper(
            "PaddleRec_PosNegRatio", pos_score=pos_score, neg_score=neg_score)

        if not isinstance(pos_score, Variable):
            raise ValueError("pos_score must be Variable, but received %s" %
                             type(pos_score))
        if not isinstance(neg_score, Variable):
            raise ValueError("neg_score must be Variable, but received %s" %
                             type(neg_score))

        # Per-batch counts of wrongly and correctly ordered elements.
        wrong = fluid.layers.cast(
            fluid.layers.less_equal(pos_score, neg_score), dtype='float32')
        wrong_cnt = fluid.layers.reduce_sum(wrong)
        right = fluid.layers.cast(
            fluid.layers.less_than(neg_score, pos_score), dtype='float32')
        right_cnt = fluid.layers.reduce_sum(right)

        # Persistable accumulators that keep the running totals.
        global_right_cnt, _ = helper.create_or_get_global_variable(
            name="right_cnt", persistable=True, dtype='float32', shape=[1])
        global_wrong_cnt, _ = helper.create_or_get_global_variable(
            name="wrong_cnt", persistable=True, dtype='float32', shape=[1])

        for var in [global_right_cnt, global_wrong_cnt]:
            helper.set_variable_initializer(
                var, Constant(
                    value=0.0, force_cpu=True))

        # In-place accumulation: the output variable is the accumulator.
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_right_cnt],
                    "Y": [right_cnt]},
            outputs={"Out": [global_right_cnt]})
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_wrong_cnt],
                    "Y": [wrong_cnt]},
            outputs={"Out": [global_wrong_cnt]})
        self.pn = (global_right_cnt + 1.0) / (global_wrong_cnt + 1.0)

        # State variables as (variable name, dtype) pairs, keyed by the
        # names that _calculate reads.
        self._global_metric_state_vars = dict()
        self._global_metric_state_vars['right_cnt'] = (global_right_cnt.name,
                                                       "float32")
        self._global_metric_state_vars['wrong_cnt'] = (global_wrong_cnt.name,
                                                       "float32")

        self.metrics = dict()
        self.metrics['WrongCnt'] = global_wrong_cnt
        self.metrics['RightCnt'] = global_right_cnt
        self.metrics['PN'] = self.pn

    def _calculate(self, global_metrics):
        """
        Format the gathered counters and the resulting PN ratio.

        Args:
            global_metrics(dict): maps 'right_cnt' and 'wrong_cnt' to
                indexable values whose first element is the count.

        Return:
            str: "RightCnt=... WrongCnt=... PN=..."

        Raises:
            ValueError: if a registered state key is missing.
        """
        # The registered states live in _global_metric_state_vars; the
        # previously referenced _global_communicate_var is never defined.
        for key in self._global_metric_state_vars:
            if key not in global_metrics:
                raise ValueError("%s not existed" % key)
        pn = (global_metrics['right_cnt'][0] + 1.0) / (
            global_metrics['wrong_cnt'][0] + 1.0)
        return "RightCnt=%s WrongCnt=%s PN=%s" % (
            str(global_metrics['right_cnt'][0]),
            str(global_metrics['wrong_cnt'][0]), str(pn))

    def get_result(self):
        """
        Return:
            dict: the metric variables ('WrongCnt', 'RightCnt', 'PN').
        """
        return self.metrics
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.metric import Metric
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable
class PrecisionRecall(Metric):
    """
    Multi-class precision / recall / F1 metric for Fluid models.
    """

    def __init__(self, input, label, class_num):
        """
        Build the precision_recall op and its accumulated state variable.

        Args:
            input(Variable): fed to topk(k=1) to obtain the predicted class.
            label(Variable): cast to int32 and used as the op's labels.
            class_num(int): number of classes (rows of the state table).

        Raises:
            ValueError: if input or label is not a Variable.
        """
        # Snapshot of the constructor arguments for LayerHelper, taken
        # before `label` is re-bound below.
        helper_kwargs = {
            "input": input,
            "label": label,
            "class_num": class_num,
        }
        self.num_cls = class_num

        if not isinstance(input, Variable):
            raise ValueError("input must be Variable, but received %s" %
                             type(input))
        if not isinstance(label, Variable):
            raise ValueError("label must be Variable, but received %s" %
                             type(label))

        layer_helper = LayerHelper("PaddleRec_PrecisionRecall",
                                   **helper_kwargs)

        int_label = fluid.layers.cast(label, dtype="int32")
        int_label.stop_gradient = True
        top_probs, top_idx = fluid.layers.nn.topk(input, k=1)
        top_idx = fluid.layers.cast(top_idx, dtype="int32")
        top_idx.stop_gradient = True

        # Persistable per-class state table accumulated across batches.
        states_info, _ = layer_helper.create_or_get_global_variable(
            name="states_info",
            persistable=True,
            dtype='float32',
            shape=[self.num_cls, 4])
        states_info.stop_gradient = True
        layer_helper.set_variable_initializer(
            states_info, Constant(
                value=0.0, force_cpu=True))

        batch_metrics, _ = layer_helper.create_or_get_global_variable(
            name="batch_metrics",
            persistable=False,
            dtype='float32',
            shape=[6])
        accum_metrics, _ = layer_helper.create_or_get_global_variable(
            name="global_metrics",
            persistable=False,
            dtype='float32',
            shape=[6])

        # Receives the op's accumulated states, then is copied back.
        batch_states = fluid.layers.fill_constant(
            shape=[self.num_cls, 4], value=0.0, dtype="float32")
        batch_states.stop_gradient = True

        op_inputs = {
            'MaxProbs': [top_probs],
            'Indices': [top_idx],
            'Labels': [int_label],
            'StatesInfo': [states_info]
        }
        op_outputs = {
            'BatchMetrics': [batch_metrics],
            'AccumMetrics': [accum_metrics],
            'AccumStatesInfo': [batch_states]
        }
        layer_helper.append_op(
            type="precision_recall",
            attrs={'class_number': self.num_cls},
            inputs=op_inputs,
            outputs=op_outputs)
        layer_helper.append_op(
            type="assign",
            inputs={'X': [batch_states]},
            outputs={'Out': [states_info]})

        batch_states.stop_gradient = True
        states_info.stop_gradient = True

        self._global_metric_state_vars = dict()
        self._global_metric_state_vars['states_info'] = (states_info.name,
                                                         "float32")

        self.metrics = dict()
        self.metrics["precision_recall_f1"] = accum_metrics
        self.metrics["[TP FP TN FN]"] = states_info

    def _calculate(self, global_metrics):
        """
        Compute macro/micro precision, recall and F1 from the state table.

        Args:
            global_metrics(dict): must contain 'states_info', indexed as
                states[class][column] with columns 0, 1 and 3 read as
                TP, FP and FN.

        Return:
            str: the state table plus the six metrics as float32.

        Raises:
            ValueError: if a registered state key is missing.
        """
        missing = [
            name for name in self._global_metric_state_vars
            if name not in global_metrics
        ]
        if missing:
            raise ValueError("%s not existed" % missing[0])

        def hit_rate(hits, misses):
            # Shared form of precision (misses=FP) and recall (misses=FN);
            # defined as 1.0 when there is nothing to count.
            if hits > 0.0 or misses > 0.0:
                return hits / (hits + misses)
            return 1.0

        def f1_of(precision, recall):
            if precision > 0.0 or recall > 0.0:
                return 2 * precision * recall / (precision + recall)
            return 0.0

        states = global_metrics["states_info"]
        tp_sum = 0.0
        fp_sum = 0.0
        fn_sum = 0.0
        macro_precision = 0.0
        macro_recall = 0.0
        for cls_id in range(self.num_cls):
            row = states[cls_id]
            tp, fp, fn = row[0], row[1], row[3]
            tp_sum += tp
            fp_sum += fp
            fn_sum += fn
            macro_precision += hit_rate(tp, fp)
            macro_recall += hit_rate(tp, fn)

        macro_precision /= self.num_cls
        macro_recall /= self.num_cls
        micro_precision = hit_rate(tp_sum, fp_sum)
        micro_recall = hit_rate(tp_sum, fn_sum)

        summary = [
            macro_precision,
            macro_recall,
            f1_of(macro_precision, macro_recall),
            micro_precision,
            micro_recall,
            f1_of(micro_precision, micro_recall),
        ]
        return "total metrics: [TP, FP, TN, FN]=%s; precision_recall_f1=%s" % (
            str(states), str(np.array(summary).astype('float32')))

    def get_result(self):
        """
        Return:
            dict: the metric variables registered in __init__.
        """
        return self.metrics
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.metric import Metric
from paddle.fluid.layers import accuracy
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.tensor import Variable
class RecallK(Metric):
    """
    Top-k accuracy (Recall@k) metric for Fluid models.
    """

    def __init__(self, input, label, k=20):
        """
        Build the ops that accumulate instance and hit counters.

        Args:
            input(Variable): passed to fluid's accuracy layer as input.
            label(Variable): passed to fluid's accuracy layer as label.
            k(int): the k handed to the accuracy layer, default 20.

        Raises:
            ValueError: if input or label is not a Variable.
        """
        # Constructor arguments forwarded to LayerHelper.
        helper_kwargs = {"input": input, "label": label, "k": k}
        self.k = k

        if not isinstance(input, Variable):
            raise ValueError("input must be Variable, but received %s" %
                             type(input))
        if not isinstance(label, Variable):
            raise ValueError("label must be Variable, but received %s" %
                             type(label))

        layer_helper = LayerHelper("PaddleRec_RecallK", **helper_kwargs)
        batch_accuracy = accuracy(input, label, self.k)

        # Persistable running totals: instances seen and instances hit.
        global_ins_cnt, _ = layer_helper.create_or_get_global_variable(
            name="ins_cnt", persistable=True, dtype='float32', shape=[1])
        global_pos_cnt, _ = layer_helper.create_or_get_global_variable(
            name="pos_cnt", persistable=True, dtype='float32', shape=[1])

        zero_init_targets = [global_ins_cnt, global_pos_cnt]
        for counter in zero_init_targets:
            layer_helper.set_variable_initializer(
                counter, Constant(
                    value=0.0, force_cpu=True))

        # Batch size is obtained by summing a ones tensor shaped like label.
        ones_like_label = fluid.layers.fill_constant(
            shape=fluid.layers.shape(label), dtype="float32", value=1.0)
        batch_ins = fluid.layers.reduce_sum(ones_like_label)
        batch_pos = batch_ins * batch_accuracy

        # In-place accumulation: the output variable is the accumulator.
        layer_helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_ins_cnt],
                    "Y": [batch_ins]},
            outputs={"Out": [global_ins_cnt]})
        layer_helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_pos_cnt],
                    "Y": [batch_pos]},
            outputs={"Out": [global_pos_cnt]})

        self.acc = global_pos_cnt / global_ins_cnt

        self._global_metric_state_vars = dict()
        self._global_metric_state_vars['ins_cnt'] = (global_ins_cnt.name,
                                                     "float32")
        self._global_metric_state_vars['pos_cnt'] = (global_pos_cnt.name,
                                                     "float32")

        self.metrics = dict()
        self.metrics["InsCnt"] = global_ins_cnt
        self.metrics["RecallCnt"] = global_pos_cnt
        self.metrics["Acc(Recall@%d)" % self.k] = self.acc

    def _calculate(self, global_metrics):
        """
        Format the gathered counters and the resulting accuracy.

        Args:
            global_metrics(dict): maps 'ins_cnt' and 'pos_cnt' to indexable
                values whose first element is the count.

        Return:
            str: "InsCnt=... RecallCnt=... Acc(Recall@k)=..."

        Raises:
            ValueError: if a registered state key is missing.
        """
        missing = [
            name for name in self._global_metric_state_vars
            if name not in global_metrics
        ]
        if missing:
            raise ValueError("%s not existed" % missing[0])

        ins_cnt = global_metrics['ins_cnt'][0]
        pos_cnt = global_metrics['pos_cnt'][0]
        # No instances seen: report 0 instead of dividing by zero.
        acc = 0 if ins_cnt == 0 else float(pos_cnt) / ins_cnt
        return "InsCnt=%s RecallCnt=%s Acc(Recall@%d)=%s" % (
            str(ins_cnt), str(pos_cnt), self.k, str(acc))

    def get_result(self):
        """
        Return:
            dict: the metric variables registered in __init__.
        """
        return self.metrics
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
import abc import abc
import os import os
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.framework import Variable
from paddlerec.core.metric import Metric
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
...@@ -39,6 +41,7 @@ class ModelBase(object): ...@@ -39,6 +41,7 @@ class ModelBase(object):
self._init_hyper_parameters() self._init_hyper_parameters()
self._env = config self._env = config
self._slot_inited = False self._slot_inited = False
self._clear_metrics = None
def _init_hyper_parameters(self): def _init_hyper_parameters(self):
pass pass
...@@ -109,8 +112,23 @@ class ModelBase(object): ...@@ -109,8 +112,23 @@ class ModelBase(object):
def get_infer_inputs(self): def get_infer_inputs(self):
return self._infer_data_var return self._infer_data_var
def get_clear_metrics(self):
if self._clear_metrics is not None:
return self._clear_metrics
self._clear_metrics = []
for key in self._infer_results:
if isinstance(self._infer_results[key], Metric):
self._clear_metrics.append(self._infer_results[key])
return self._clear_metrics
def get_infer_results(self): def get_infer_results(self):
return self._infer_results res = dict()
for key in self._infer_results:
if isinstance(self._infer_results[key], Metric):
res.update(self._infer_results[key].get_result())
elif isinstance(self._infer_results[key], Variable):
res[key] = self._infer_results[key]
return res
def get_avg_cost(self): def get_avg_cost(self):
"""R """R
...@@ -120,7 +138,13 @@ class ModelBase(object): ...@@ -120,7 +138,13 @@ class ModelBase(object):
def get_metrics(self): def get_metrics(self):
"""R """R
""" """
return self._metrics res = dict()
for key in self._metrics:
if isinstance(self._metrics[key], Metric):
res.update(self._metrics[key].get_result())
elif isinstance(self._metrics[key], Variable):
res[key] = self._metrics[key]
return res
def get_fetch_period(self): def get_fetch_period(self):
return self._fetch_interval return self._fetch_interval
......
...@@ -20,6 +20,7 @@ import numpy as np ...@@ -20,6 +20,7 @@ import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
from paddlerec.core.metric import Metric
__all__ = [ __all__ = [
"RunnerBase", "SingleRunner", "PSRunner", "CollectiveRunner", "PslibRunner" "RunnerBase", "SingleRunner", "PSRunner", "CollectiveRunner", "PslibRunner"
...@@ -77,9 +78,10 @@ class RunnerBase(object): ...@@ -77,9 +78,10 @@ class RunnerBase(object):
name = "dataset." + reader_name + "." name = "dataset." + reader_name + "."
if envs.get_global_env(name + "type") == "DataLoader": if envs.get_global_env(name + "type") == "DataLoader":
self._executor_dataloader_train(model_dict, context) return self._executor_dataloader_train(model_dict, context)
else: else:
self._executor_dataset_train(model_dict, context) self._executor_dataset_train(model_dict, context)
return None
def _executor_dataset_train(self, model_dict, context): def _executor_dataset_train(self, model_dict, context):
reader_name = model_dict["dataset_name"] reader_name = model_dict["dataset_name"]
...@@ -137,8 +139,10 @@ class RunnerBase(object): ...@@ -137,8 +139,10 @@ class RunnerBase(object):
metrics_varnames = [] metrics_varnames = []
metrics_format = [] metrics_format = []
metrics_names = ["total_batch"]
metrics_format.append("{}: {{}}".format("batch")) metrics_format.append("{}: {{}}".format("batch"))
for name, var in metrics.items(): for name, var in metrics.items():
metrics_names.append(name)
metrics_varnames.append(var.name) metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name)) metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format) metrics_format = ", ".join(metrics_format)
...@@ -147,6 +151,7 @@ class RunnerBase(object): ...@@ -147,6 +151,7 @@ class RunnerBase(object):
reader.start() reader.start()
batch_id = 0 batch_id = 0
scope = context["model"][model_name]["scope"] scope = context["model"][model_name]["scope"]
result = None
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
try: try:
while True: while True:
...@@ -168,6 +173,10 @@ class RunnerBase(object): ...@@ -168,6 +173,10 @@ class RunnerBase(object):
except fluid.core.EOFException: except fluid.core.EOFException:
reader.reset() reader.reset()
if batch_id > 0:
result = dict(zip(metrics_names, metrics))
return result
def _get_dataloader_program(self, model_dict, context): def _get_dataloader_program(self, model_dict, context):
model_name = model_dict["name"] model_name = model_dict["name"]
if context["model"][model_name]["compiled_program"] == None: if context["model"][model_name]["compiled_program"] == None:
...@@ -221,6 +230,7 @@ class RunnerBase(object): ...@@ -221,6 +230,7 @@ class RunnerBase(object):
program = context["model"][model_name]["main_program"].clone() program = context["model"][model_name]["main_program"].clone()
_exe_strategy, _build_strategy = self._get_strategy(model_dict, _exe_strategy, _build_strategy = self._get_strategy(model_dict,
context) context)
program = fluid.compiler.CompiledProgram(program).with_data_parallel( program = fluid.compiler.CompiledProgram(program).with_data_parallel(
loss_name=model_class.get_avg_cost().name, loss_name=model_class.get_avg_cost().name,
build_strategy=_build_strategy, build_strategy=_build_strategy,
...@@ -335,11 +345,28 @@ class SingleRunner(RunnerBase): ...@@ -335,11 +345,28 @@ class SingleRunner(RunnerBase):
".epochs")) ".epochs"))
for epoch in range(epochs): for epoch in range(epochs):
for model_dict in context["phases"]: for model_dict in context["phases"]:
model_class = context["model"][model_dict["name"]]["model"]
metrics = model_class._metrics
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) result = self._run(context, model_dict)
end_time = time.time() end_time = time.time()
seconds = end_time - begin_time seconds = end_time - begin_time
print("epoch {} done, use time: {}".format(epoch, seconds)) message = "epoch {} done, use time: {}".format(epoch, seconds)
metrics_result = []
for key in metrics:
if isinstance(metrics[key], Metric):
_str = metrics[key].calc_global_metrics(
None,
context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str)
elif result is not None:
_str = "{}={}".format(key, result[key])
metrics_result.append(_str)
if len(metrics_result) > 0:
message += ", global metrics: " + ", ".join(metrics_result)
print(message)
with fluid.scope_guard(context["model"][model_dict["name"]][ with fluid.scope_guard(context["model"][model_dict["name"]][
"scope"]): "scope"]):
train_prog = context["model"][model_dict["name"]][ train_prog = context["model"][model_dict["name"]][
...@@ -361,12 +388,32 @@ class PSRunner(RunnerBase): ...@@ -361,12 +388,32 @@ class PSRunner(RunnerBase):
envs.get_global_env("runner." + context["runner_name"] + envs.get_global_env("runner." + context["runner_name"] +
".epochs")) ".epochs"))
model_dict = context["env"]["phase"][0] model_dict = context["env"]["phase"][0]
model_class = context["model"][model_dict["name"]]["model"]
metrics = model_class._metrics
for epoch in range(epochs): for epoch in range(epochs):
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) result = self._run(context, model_dict)
end_time = time.time() end_time = time.time()
seconds = end_time - begin_time seconds = end_time - begin_time
print("epoch {} done, use time: {}".format(epoch, seconds)) message = "epoch {} done, use time: {}".format(epoch, seconds)
# TODO, wait for PaddleCloudRoleMaker supports gloo
from paddle.fluid.incubate.fleet.base.role_maker import GeneralRoleMaker
if context["fleet"] is not None and isinstance(context["fleet"],
GeneralRoleMaker):
metrics_result = []
for key in metrics:
if isinstance(metrics[key], Metric):
_str = metrics[key].calc_global_metrics(
context["fleet"],
context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str)
elif result is not None:
_str = "{}={}".format(key, result[key])
metrics_result.append(_str)
if len(metrics_result) > 0:
message += ", global metrics: " + ", ".join(metrics_result)
print(message)
with fluid.scope_guard(context["model"][model_dict["name"]][ with fluid.scope_guard(context["model"][model_dict["name"]][
"scope"]): "scope"]):
train_prog = context["model"][model_dict["name"]][ train_prog = context["model"][model_dict["name"]][
...@@ -473,16 +520,33 @@ class SingleInferRunner(RunnerBase): ...@@ -473,16 +520,33 @@ class SingleInferRunner(RunnerBase):
def run(self, context): def run(self, context):
self._dir_check(context) self._dir_check(context)
self.epoch_model_name_list.sort()
for index, epoch_name in enumerate(self.epoch_model_name_list): for index, epoch_name in enumerate(self.epoch_model_name_list):
for model_dict in context["phases"]: for model_dict in context["phases"]:
model_class = context["model"][model_dict["name"]]["model"]
metrics = model_class._infer_results
self._load(context, model_dict, self._load(context, model_dict,
self.epoch_model_path_list[index]) self.epoch_model_path_list[index])
begin_time = time.time() begin_time = time.time()
self._run(context, model_dict) result = self._run(context, model_dict)
end_time = time.time() end_time = time.time()
seconds = end_time - begin_time seconds = end_time - begin_time
print("Infer {} of {} done, use time: {}".format(model_dict[ message = "Infer {} of epoch {} done, use time: {}".format(
"name"], epoch_name, seconds)) model_dict["name"], epoch_name, seconds)
metrics_result = []
for key in metrics:
if isinstance(metrics[key], Metric):
_str = metrics[key].calc_global_metrics(
None,
context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str)
elif result is not None:
_str = "{}={}".format(key, result[key])
metrics_result.append(_str)
if len(metrics_result) > 0:
message += ", global metrics: " + ", ".join(metrics_result)
print(message)
context["status"] = "terminal_pass" context["status"] = "terminal_pass"
def _load(self, context, model_dict, model_path): def _load(self, context, model_dict, model_path):
...@@ -497,6 +561,10 @@ class SingleInferRunner(RunnerBase): ...@@ -497,6 +561,10 @@ class SingleInferRunner(RunnerBase):
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
fluid.io.load_persistables( fluid.io.load_persistables(
context["exe"], model_path, main_program=train_prog) context["exe"], model_path, main_program=train_prog)
clear_metrics = context["model"][model_dict["name"]][
"model"].get_clear_metrics()
for var in clear_metrics:
var.clear()
def _dir_check(self, context): def _dir_check(self, context):
dirname = envs.get_global_env( dirname = envs.get_global_env(
......
# 如何给模型增加Metric
## PaddleRec Metric使用示例
```
from paddlerec.core.model import ModelBase
from paddlerec.core.metrics import RecallK
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
def net(self, inputs, is_infer=False):
...
acc = RecallK(input=logits, label=label, k=20)
self._metrics["Train_P@20"] = acc
```
## Metric类
### 成员变量
> _global_metric_state_vars(dict),
字典类型,用以存储metric计算过程中需要的中间状态变量。一般情况下,这些中间状态需要是Persistable=True的变量,所以在模型保存的时候也会被保存下来。因此infer阶段需手动将这些中间状态值清零,进而保证预测结果的正确性。
### 成员函数
> clear(self, scope):
从scope中将self._global_metric_state_vars中的状态值全清零。该函数一般用在**infer**阶段开始的时候。用以保证预测指标的正确性。
> calc_global_metrics(self, fleet, scope=None):
将self._global_metric_state_vars中的状态值在所有训练节点上做all_reduce操作,进而下一步调用_calculate()函数计算全局指标。若fleet=None,则all_reduce的结果为自己本身,即单机全局指标计算。
> get_result(self): 返回训练过程中需要fetch,并定期打印至屏幕的变量。返回类型为dict。
## Metrics
### AUC
> AUC(input ,label, curve='ROC', num_thresholds=2**12 - 1, topk=1, slide_steps=1)
Auc,全称Area Under the Curve(AUC),该层根据前向输出和标签计算AUC,在二分类(binary classification)估计中广泛使用。相关定义参考 https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve 。
#### 参数
- **input(Tensor|LoDTensor)**: 数据类型为float32,float64。浮点二维变量。输入为网络的预测值。shape为[batch_size, 2]。
- **label(Tensor|LoDTensor)**: 数据类型为int64,int32。输入为数据集的标签。shape为[batch_size, 1]。
- **curve(str)**: 曲线类型,可以为 ROC 或 PR,默认 ROC。
- **num_thresholds(int)**: 将roc曲线离散化时使用的临界值数。默认`2**12 - 1`(即4095)。
- **topk(int)**: 取topk的输出值用于计算。
- **slide_steps(int)**: 计算batch auc时,除当前步外还可以使用先前步的统计结果。slide_steps=1,表示只用当前步;slide_steps=3,表示用当前步和前两步;slide_steps=0,则用所有步。
#### 返回值
该指标在训练过程中定期打印的变量有两个:
- **AUC**: 整体AUC值
- **BATCH_AUC**:当前batch的AUC值
### PrecisionRecall
> PrecisionRecall(input, label, class_num)
计算precision, recall, f1。
#### 参数
- **input(Tensor|LoDTensor)**: 数据类型为float32,float64。输入为网络的预测值。shape为[batch_size, class_num]
- **label(Tensor|LoDTensor)**: 数据类型为int32。输入为数据集的标签。shape为 [batch_size, 1]
- **class_num(int)**: 类别个数。
#### 返回值
- **[TP FP TN FN]**: 形状为[class_num, 4]的变量,用以表征每种类型的TP,FP,TN和FN值。TP=true positive, FP=false positive, TN=true negative, FN=false negative。若需计算每种类型的precision, recall,f1, 则可根据如下公式进行计算:
precision = TP / (TP + FP); recall = TP / (TP + FN); F1 = 2 * precision * recall / (precision + recall)。
- **precision_recall_f1**: 形状为[6],分别代表[macro_avg_precision, macro_avg_recall, macro_avg_f1, micro_avg_precision, micro_avg_recall, micro_avg_f1],这里macro代表先计算每种类型的准确率,召回率,F1,然后求平均。micro代表先计算所有类型的整体TP,TN, FP, FN等中间值,然后再计算准确率,召回率,F1。
### RecallK
> RecallK(input, label, k=20)
TopK的召回准确率,对于任意一条样本来说,若前top_k个分类结果中包含正确分类标签,则视为正样本。
#### 参数
- **input(Tensor|LoDTensor)**: 数据类型为float32,float64。输入为网络的预测值。shape为[batch_size, class_dim]
- **label(Tensor|LoDTensor)**: 数据类型为int64,int32。输入为数据集的标签。shape为 [batch_size, 1]
- **k(int)**: 取每个类别中top_k个预测值用于计算召回准确率。
#### 返回值
- **InsCnt**:样本总数
- **RecallCnt**: topk可以正确被召回的样本数
- **Acc(Recall@k)**: RecallCnt/InsCnt,即Topk召回准确率。
### PairWise_PN
> PosNegRatio(pos_score, neg_score)
正逆序指标,一般用在输入是pairwise的模型中。例如输入既包含正样本,也包含负样本,模型需要去学习最大化正负样本打分的差异。
#### 参数
- **pos_score(Tensor|LoDTensor)**: 正样本的打分,数据类型为float32,float64。浮点二维变量,值的范围为[0,1]。
- **neg_score(Tensor|LoDTensor)**:负样本的打分。数据类型为float32,float64。浮点二维变量,值的范围为[0,1]。
#### 返回值
- **RightCnt**: pos_score > neg_score的样本数
- **WrongCnt**: pos_score <= neg_score的样本数
- **PN**: (RightCnt + 1.0) / (WrongCnt + 1.0), 正逆序,+1.0是为了避免除0错误。
### Customized_Metric
如果你需要自定义metric,那么你需要按如下步骤操作:
1. 继承paddlerec.core.Metric,定义你的MyMetric类。
2. 在MyMetric的构造函数中,自定义Metric组网,声明self._global_metric_state_vars私有变量。
3. 定义_calculate(global_metrics),全局指标计算。该函数的输入global_metrics,存储了self._global_metric_state_vars中所有中间状态变量的全局统计值。最终结果以str格式返回。
自定义Metric模版如下,你可以参考注释,或paddlerec.core.metrics下已经实现的precision_recall, auc, pairwise_pn, recall_k等指标的计算方式,自定义自己的Metric类。
```
from paddlerec.core.Metric import Metric
class MyMetric(Metric):
def __init__(self):
# 1. 自定义Metric组网
** 1. your code **
# 2. 设置中间状态字典
self._global_metric_state_vars = dict()
** 2. your code **
def get_result(self):
# 3. 定义训练过程中需要打印的变量,以字典格式返回
self. _metrics = dict()
** 3. your code **
def _calculate(self, global_metrics):
# 4. 全局指标计算,global_metrics为字典类型,存储了self._global_metric_state_vars中所有中间状态变量的全局统计值。返回格式为str。
** your code **
```
...@@ -113,6 +113,8 @@ def input_data(self, is_infer=False, **kwargs): ...@@ -113,6 +113,8 @@ def input_data(self, is_infer=False, **kwargs):
可以参考官方模型的示例学习net的构造方法。 可以参考官方模型的示例学习net的构造方法。
除可以使用Paddle的Metrics接口外,PaddleRec也统一封装了一些常见的Metrics评价指标,并允许开发者定义自己的Metrics类,相关文件参考[Metrics开发文档](metrics.md)
## 如何运行自定义模型 ## 如何运行自定义模型
记录`model.py`,`config.yaml`及数据读取`reader.py`的文件路径,建议置于同一文件夹下,如`/home/custom_model`下,更改`config.yaml`中的配置选项 记录`model.py`,`config.yaml`及数据读取`reader.py`的文件路径,建议置于同一文件夹下,如`/home/custom_model`下,更改`config.yaml`中的配置选项
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddlerec.core.metrics import AUC
import paddle
import paddle.fluid as fluid
class TestAUC(unittest.TestCase):
    """Compare the AUC metric's fetched value with fluid.metrics.Auc."""

    def setUp(self):
        """Create random batches and the reference AUC accumulated over them."""
        self.ins_num = 64
        self.batch_nums = 3

        self.datas = []
        for _ in range(self.batch_nums):
            raw_probs = np.random.uniform(0, 1.0, (self.ins_num, 2))
            probs = raw_probs.astype('float32')
            raw_labels = np.random.choice(range(2), self.ins_num)
            labels = raw_labels.reshape((self.ins_num, 1)).astype('int64')
            self.datas.append((probs, labels))

        self.place = fluid.core.CPUPlace()
        self.num_thresholds = 2**12

        # Reference value: feed the same batches to the Python-side metric.
        reference = fluid.metrics.Auc(
            name="auc", curve='ROC', num_thresholds=self.num_thresholds)
        for probs, labels in self.datas:
            reference.update(probs, labels)
        self.auc = np.array(reference.eval())

    def build_network(self):
        """Declare the two data inputs and return the AUC metric built on them."""
        predict = fluid.data(
            name="predict", shape=[-1, 2], dtype='float32', lod_level=0)
        label = fluid.data(
            name="label", shape=[-1, 1], dtype='int64', lod_level=0)
        return AUC(input=predict,
                   label=label,
                   num_thresholds=self.num_thresholds,
                   curve='ROC')

    def test_forward(self):
        """Run every batch and check the last fetched 'AUC' value."""
        auc_metric = self.build_network()
        metrics = auc_metric.get_result()

        metric_keys = []
        fetch_vars = []
        for key, var in metrics.items():
            metric_keys.append(key)
            fetch_vars.append(var)

        exe = fluid.Executor(self.place)
        exe.run(fluid.default_startup_program())

        main_program = fluid.default_main_program()
        for probs, labels in self.datas:
            outs = exe.run(main_program,
                           feed={'predict': probs,
                                 'label': labels},
                           fetch_list=fetch_vars,
                           return_numpy=True)

        # Only the values fetched after the final batch are compared.
        outs = dict(zip(metric_keys, outs))
        self.assertTrue(np.allclose(outs['AUC'], self.auc))

    def test_exception(self):
        """Constructing AUC without arguments or from numpy arrays must raise."""
        self.assertRaises(Exception, AUC)
        probs, labels = self.datas[0]
        self.assertRaises(Exception, AUC, input=probs, label=labels)


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddlerec.core.metrics import PosNegRatio
import paddle
import paddle.fluid as fluid
class TestPosNegRatio(unittest.TestCase):
    """Compare the PosNegRatio metric's fetched values with numpy counts."""

    def setUp(self):
        """Create random score pairs and accumulate the expected counts."""
        self.ins_num = 64
        self.batch_nums = 3
        self.datas = []
        self.right_cnt = 0.0
        self.wrong_cnt = 0.0

        score_shape = (self.ins_num, 1)
        for _ in range(self.batch_nums):
            # Negative scores are drawn first, then positive scores.
            neg_score = np.random.uniform(0, 1.0, score_shape).astype('float32')
            pos_score = np.random.uniform(0, 1.0, score_shape).astype('float32')

            batch_right = np.sum(np.less(neg_score, pos_score)).astype('int32')
            batch_wrong = np.sum(np.less_equal(pos_score,
                                               neg_score)).astype('int32')
            self.right_cnt += float(batch_right)
            self.wrong_cnt += float(batch_wrong)
            self.datas.append((pos_score, neg_score))

        self.place = fluid.core.CPUPlace()

    def build_network(self):
        """Declare both score inputs and return the PosNegRatio metric."""
        pos_score = fluid.data(
            name="pos_score", shape=[-1, 1], dtype='float32', lod_level=0)
        neg_score = fluid.data(
            name="neg_score", shape=[-1, 1], dtype='float32', lod_level=0)
        return PosNegRatio(pos_score=pos_score, neg_score=neg_score)

    def test_forward(self):
        """Run every batch and check the counts fetched after the last one."""
        pn = self.build_network()
        metrics = pn.get_result()

        metric_keys = []
        fetch_vars = []
        for key, var in metrics.items():
            metric_keys.append(key)
            fetch_vars.append(var)

        exe = fluid.Executor(self.place)
        exe.run(fluid.default_startup_program())

        main_program = fluid.default_main_program()
        for pos_score, neg_score in self.datas:
            outs = exe.run(main_program,
                           feed={
                               'pos_score': pos_score,
                               'neg_score': neg_score
                           },
                           fetch_list=fetch_vars,
                           return_numpy=True)

        outs = dict(zip(metric_keys, outs))
        expected_pn = np.array(
            (self.right_cnt + 1.0) / (self.wrong_cnt + 1.0))
        self.assertTrue(np.allclose(outs['RightCnt'], self.right_cnt))
        self.assertTrue(np.allclose(outs['WrongCnt'], self.wrong_cnt))
        self.assertTrue(np.allclose(outs['PN'], expected_pn))

    def test_exception(self):
        """Constructing PosNegRatio without arguments or from numpy arrays must raise."""
        self.assertRaises(Exception, PosNegRatio)
        pos_score, neg_score = self.datas[0]
        self.assertRaises(
            Exception, PosNegRatio, pos_score=pos_score, neg_score=neg_score)


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddlerec.core.metrics import PrecisionRecall
import paddle
import paddle.fluid as fluid
def calc_precision(tp_count, fp_count):
    """Return TP / (TP + FP), or 1.0 when neither count is positive."""
    has_positive_count = tp_count > 0.0 or fp_count > 0.0
    if not has_positive_count:
        # No true or false positives recorded: fall back to 1.0.
        return 1.0
    return tp_count / (tp_count + fp_count)
def calc_recall(tp_count, fn_count):
    """Return TP / (TP + FN), or 1.0 when neither count is positive."""
    if not (tp_count > 0.0 or fn_count > 0.0):
        # No true positives or false negatives recorded: fall back to 1.0.
        return 1.0
    denominator = tp_count + fn_count
    return tp_count / denominator
def calc_f1_score(precision, recall):
    """Return 2 * P * R / (P + R), or 0.0 when neither value is positive."""
    if not (precision > 0.0 or recall > 0.0):
        return 0.0
    numerator = 2 * precision * recall
    return numerator / (precision + recall)
def get_states(idxs, labels, cls_num, weights=None, batch_nums=1):
    """Accumulate per-class [TP, FP, TN, FN] counts for one batch.

    Args:
        idxs: predicted class per instance, read as ``idxs[i][0]``.
        labels: true class per instance, read as ``labels[i][0]``.
        cls_num: number of classes (rows of the result).
        weights: optional per-instance weights; every instance counts as
            1.0 when omitted.
        batch_nums: accepted but not used by this function.

    Returns:
        A float32 array of shape ``(cls_num, 4)`` with columns TP, FP, TN, FN.
    """
    tp_col, fp_col, tn_col, fn_col = 0, 1, 2, 3
    states = np.zeros((cls_num, 4)).astype('float32')

    for row in range(idxs.shape[0]):
        weight = 1.0 if weights is None else weights[row]
        predicted = idxs[row][0]
        actual = labels[row][0]

        # Every class starts as a true negative for this instance; the
        # classes actually involved are taken back out below.
        for cls in range(cls_num):
            states[cls][tn_col] += weight
        states[predicted][tn_col] -= weight

        if predicted == actual:
            states[predicted][tp_col] += weight
        else:
            states[actual][fn_col] += weight
            states[predicted][fp_col] += weight
            states[actual][tn_col] -= weight

    return states
def compute_metrics(states, cls_num):
    """Derive macro and micro precision/recall/F1 from per-class counts.

    Args:
        states: per-class counts indexed as ``states[i][0]`` (TP),
            ``states[i][1]`` (FP) and ``states[i][3]`` (FN).
        cls_num: number of classes to read from ``states``.

    Returns:
        A float32 array ``[macro_precision, macro_recall, macro_f1,
        micro_precision, micro_recall, micro_f1]``.
    """
    tp_total = 0.0
    fp_total = 0.0
    fn_total = 0.0
    precision_sum = 0.0
    recall_sum = 0.0

    for cls in range(cls_num):
        tp, fp, fn = states[cls][0], states[cls][1], states[cls][3]
        tp_total += tp
        fp_total += fp
        fn_total += fn
        precision_sum += calc_precision(tp, fp)
        recall_sum += calc_recall(tp, fn)

    # Macro: average of the per-class scores.
    macro_precision = precision_sum / cls_num
    macro_recall = recall_sum / cls_num
    # Micro: scores computed from the counts summed over all classes.
    micro_precision = calc_precision(tp_total, fp_total)
    micro_recall = calc_recall(tp_total, fn_total)

    metrics = [
        macro_precision,
        macro_recall,
        calc_f1_score(macro_precision, macro_recall),
        micro_precision,
        micro_recall,
        calc_f1_score(micro_precision, micro_recall),
    ]
    return np.array(metrics).astype('float32')
class TestPrecisionRecall(unittest.TestCase):
    """Compare the PrecisionRecall metric with the numpy reference helpers."""

    def setUp(self):
        """Create random batches and the expected states and metrics."""
        self.ins_num = 64
        self.cls_num = 10
        self.batch_nums = 3
        self.datas = []
        self.states = np.zeros((self.cls_num, 4)).astype('float32')

        for _ in range(self.batch_nums):
            probs = np.random.uniform(
                0, 1.0, (self.ins_num, self.cls_num)).astype('float32')
            # The predicted class is the arg-max of each row of scores.
            predicted = np.argmax(probs, axis=1)
            idxs = np.array(predicted).reshape(self.ins_num,
                                               1).astype('int32')
            raw_labels = np.random.choice(range(self.cls_num), self.ins_num)
            labels = raw_labels.reshape((self.ins_num, 1)).astype('int32')

            self.datas.append((probs, labels))
            batch_states = get_states(idxs, labels, self.cls_num)
            self.states = np.add(self.states, batch_states)

        self.metrics = compute_metrics(self.states, self.cls_num)
        self.place = fluid.core.CPUPlace()

    def build_network(self):
        """Declare the data inputs and return the PrecisionRecall metric."""
        predict = fluid.data(
            name="predict",
            shape=[-1, self.cls_num],
            dtype='float32',
            lod_level=0)
        label = fluid.data(
            name="label", shape=[-1, 1], dtype='int32', lod_level=0)
        return PrecisionRecall(
            input=predict, label=label, class_num=self.cls_num)

    def test_forward(self):
        """Run every batch and check the values fetched after the last one."""
        precision_recall = self.build_network()
        metrics = precision_recall.get_result()

        metric_keys = []
        fetch_vars = []
        for key, var in metrics.items():
            metric_keys.append(key)
            fetch_vars.append(var)

        exe = fluid.Executor(self.place)
        exe.run(fluid.default_startup_program())

        main_program = fluid.default_main_program()
        for probs, labels in self.datas:
            outs = exe.run(main_program,
                           feed={'predict': probs,
                                 'label': labels},
                           fetch_list=fetch_vars,
                           return_numpy=True)

        outs = dict(zip(metric_keys, outs))
        self.assertTrue(np.allclose(outs['[TP FP TN FN]'], self.states))
        self.assertTrue(
            np.allclose(outs['precision_recall_f1'], self.metrics))

    def test_exception(self):
        """Constructing PrecisionRecall without arguments or from numpy arrays must raise."""
        self.assertRaises(Exception, PrecisionRecall)
        probs, labels = self.datas[0]
        self.assertRaises(
            Exception,
            PrecisionRecall,
            input=probs,
            label=labels,
            class_num=self.cls_num)


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddlerec.core.metrics import RecallK
import paddle
import paddle.fluid as fluid
class TestRecallK(unittest.TestCase):
    """Compare the RecallK metric's fetched values with a numpy top-k count."""

    def setUp(self):
        """Create random softmax batches and count the expected top-k hits."""
        self.ins_num = 64
        self.cls_num = 10
        self.topk = 2
        self.batch_nums = 3
        self.datas = []
        self.match_num = 0.0

        for _ in range(self.batch_nums):
            z = np.random.uniform(
                0, 1.0, (self.ins_num, self.cls_num)).astype('float32')
            # Softmax over the class axis so each row is a distribution for
            # one sample.  The builtin sum() reduced over the batch axis
            # instead, which normalised every class column across samples.
            exp_z = np.exp(z)
            pred = exp_z / np.sum(exp_z, axis=1, keepdims=True)
            label = np.random.choice(range(self.cls_num),
                                     self.ins_num).reshape(
                                         (self.ins_num, 1)).astype('int64')
            self.datas.append((pred, label))

            # Indices of the k highest scores per row, best first.
            max_k_preds = pred.argsort(axis=1)[:, -self.topk:][:, ::-1]
            # A sample is a hit when its label is among its top-k indices.
            match_array = np.logical_or.reduce(max_k_preds == label, axis=1)
            self.match_num += np.sum(match_array).astype('float32')

        self.place = fluid.core.CPUPlace()

    def build_network(self):
        """Declare the data inputs and return the RecallK metric built on them."""
        pred = fluid.data(
            name="pred",
            shape=[-1, self.cls_num],
            dtype='float32',
            lod_level=0)
        label = fluid.data(
            name="label", shape=[-1, 1], dtype='int64', lod_level=0)
        recall_k = RecallK(input=pred, label=label, k=self.topk)
        return recall_k

    def test_forward(self):
        """Run every batch and check the values fetched after the last one."""
        net = self.build_network()
        metrics = net.get_result()

        metric_keys = []
        fetch_vars = []
        for key, var in metrics.items():
            metric_keys.append(key)
            fetch_vars.append(var)

        exe = fluid.Executor(self.place)
        exe.run(fluid.default_startup_program())

        for pred, label in self.datas:
            outs = exe.run(fluid.default_main_program(),
                           feed={'pred': pred,
                                 'label': label},
                           fetch_list=fetch_vars,
                           return_numpy=True)

        outs = dict(zip(metric_keys, outs))
        total_ins = self.ins_num * self.batch_nums
        self.assertTrue(np.allclose(outs['InsCnt'], total_ins))
        self.assertTrue(np.allclose(outs['RecallCnt'], self.match_num))
        self.assertTrue(
            np.allclose(outs['Acc(Recall@%d)' % (self.topk)],
                        np.array(self.match_num / total_ins)))

    def test_exception(self):
        """Constructing RecallK without arguments or from numpy arrays must raise."""
        self.assertRaises(Exception, RecallK)
        self.assertRaises(
            Exception, RecallK, input=self.datas[0][0],
            label=self.datas[0][1])


if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册