From 2cbf686a4ba17af7e03dc0d86e7c1142eb78986a Mon Sep 17 00:00:00 2001
From: malin10
Date: Thu, 30 Jul 2020 17:26:12 +0800
Subject: [PATCH] bug fix

---
 core/metric.py                                | 29 ++++++++++++--------
 core/metrics/binary_class/auc.py              | 14 +++++------
 core/metrics/binary_class/precision_recall.py |  7 +++---
 core/metrics/pairwise_pn.py                   |  2 +-
 core/metrics/recall_k.py                      |  2 +-
 core/trainers/framework/runner.py             |  6 ++---
 6 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/core/metric.py b/core/metric.py
index d621a06e..ae91cd6d 100755
--- a/core/metric.py
+++ b/core/metric.py
@@ -23,11 +23,13 @@ class Metric(object):
     __metaclass__ = abc.ABCMeta
 
     def __init__(self, config):
-        """ """
+        """R
+        """
         pass
 
     def clear(self, scope=None):
-        """ """
+        """R
+        """
         if scope is None:
             scope = fluid.global_scope()
 
@@ -41,9 +43,13 @@ class Metric(object):
             data_array = np.zeros(var._get_dims()).astype(dtype)
             var.set(data_array, place)
 
-    def get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
-        """ """
-        input = np.array(scope.find_var(metric_name).get_tensor())
+    def _get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
+        """R
+        """
+        var = scope.find_var(metric_name)
+        if not var:
+            return None
+        input = np.array(var.get_tensor())
         if fleet is None:
             return input
         fleet._role_maker._barrier_worker()
@@ -54,8 +60,9 @@ class Metric(object):
         output = output.reshape(old_shape)
         return output
 
-    def cal_global_metrics(self, fleet, scope=None):
-        """ """
+    def calc_global_metrics(self, fleet, scope=None):
+        """R
+        """
         if scope is None:
             scope = fluid.global_scope()
 
@@ -65,9 +72,9 @@ class Metric(object):
-            global_metrics[key] = self.get_global_metric_state(fleet, scope,
-                                                               varname)
+            global_metrics[key] = self._get_global_metric_state(fleet, scope,
+                                                                varname)
 
-        return self.calculate(global_metrics)
+        return self._calculate(global_metrics)
 
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         pass
 
     @abc.abstractmethod
diff --git a/core/metrics/binary_class/auc.py b/core/metrics/binary_class/auc.py
index b8473146..129b8bc7 100755
--- a/core/metrics/binary_class/auc.py
+++ b/core/metrics/binary_class/auc.py
@@ -74,7 +74,7 @@ class AUC(Metric):
         self.metrics["AUC"] = auc_out
         self.metrics["BATCH_AUC"] = batch_auc_out
 
-    def calculate_bucket_error(self, global_pos, global_neg):
+    def _calculate_bucket_error(self, global_pos, global_neg):
         """R
         """
         num_bucket = len(global_pos)
@@ -122,7 +122,7 @@ class AUC(Metric):
         bucket_error = error_sum / error_count if error_count > 0 else 0.0
         return bucket_error
 
-    def calculate_auc(self, global_pos, global_neg):
+    def _calculate_auc(self, global_pos, global_neg):
         """R
         """
         num_bucket = len(global_pos)
@@ -148,7 +148,7 @@ class AUC(Metric):
         auc_value = area / (pos * neg)
         return auc_value
 
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         result = dict()
         for key in self._global_metric_state_vars:
             if key not in global_metrics:
@@ -165,10 +165,10 @@ class AUC(Metric):
             result['copc'] = 0
             result['mean_q'] = 0
         else:
-            result['auc'] = self.calculate_auc(result['stat_pos'],
-                                               result['stat_neg'])
-            result['bucket_error'] = self.calculate_auc(result['stat_pos'],
-                                                        result['stat_neg'])
+            result['auc'] = self._calculate_auc(result['stat_pos'],
+                                                result['stat_neg'])
+            result['bucket_error'] = self._calculate_bucket_error(
+                result['stat_pos'], result['stat_neg'])
             result['actual_ctr'] = result['pos_ins_num'] / result[
                 'total_ins_num']
             result['mae'] = result['abserr'] / result['total_ins_num']
diff --git a/core/metrics/binary_class/precision_recall.py b/core/metrics/binary_class/precision_recall.py
index 0eb80765..a40b1e19 100755
--- a/core/metrics/binary_class/precision_recall.py
+++ b/core/metrics/binary_class/precision_recall.py
@@ -29,7 +29,8 @@ class PrecisionRecall(Metric):
     """
 
     def __init__(self, **kwargs):
-        """ """
+        """R
+        """
         if "input" not in kwargs or "label" not in kwargs or "class_num" not in kwargs:
             raise ValueError(
                 "PrecisionRecall expect input, label and class_num as inputs.")
@@ -107,9 +108,7 @@ class PrecisionRecall(Metric):
         self.metrics["precision_recall_f1"] = accum_metrics
         self.metrics["[TP FP TN FN]"] = states_info
 
-        # self.metrics["batch_metrics"] = batch_metrics
-
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         for key in self._global_metric_state_vars:
             if key not in global_metrics:
                 raise ValueError("%s not existed" % key)
diff --git a/core/metrics/pairwise_pn.py b/core/metrics/pairwise_pn.py
index 673ce79b..156a8606 100755
--- a/core/metrics/pairwise_pn.py
+++ b/core/metrics/pairwise_pn.py
@@ -84,7 +84,7 @@ class PosNegRatio(Metric):
         self.metrics['RightCnt'] = global_right_cnt
         self.metrics['PN'] = self.pn
 
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         for key in self._global_communicate_var:
             if key not in global_metrics:
                 raise ValueError("%s not existed" % key)
diff --git a/core/metrics/recall_k.py b/core/metrics/recall_k.py
index f570ef22..27ade145 100755
--- a/core/metrics/recall_k.py
+++ b/core/metrics/recall_k.py
@@ -88,7 +88,7 @@ class RecallK(Metric):
         self.metrics[metric_name] = self.acc
         # self.metrics["batch_metrics"] = batch_metrics
 
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         for key in self._global_metric_state_vars:
             if key not in global_metrics:
                 raise ValueError("%s not existed" % key)
diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py
index 91164a65..5da1f5df 100644
--- a/core/trainers/framework/runner.py
+++ b/core/trainers/framework/runner.py
@@ -356,7 +356,7 @@ class SingleRunner(RunnerBase):
             metrics_result = []
             for key in metrics:
                 if isinstance(metrics[key], Metric):
-                    _str = metrics[key].cal_global_metrics(
+                    _str = metrics[key].calc_global_metrics(
                         None, context["model"][model_dict["name"]]["scope"])
                     metrics_result.append(_str)
 
@@ -404,7 +404,7 @@ class PSRunner(RunnerBase):
             metrics_result = []
             for key in metrics:
                 if isinstance(metrics[key], Metric):
-                    _str = metrics[key].cal_global_metrics(
+                    _str = metrics[key].calc_global_metrics(
                         context["fleet"],
                         context["model"][model_dict["name"]]["scope"])
                     metrics_result.append(_str)
@@ -536,7 +536,7 @@ class SingleInferRunner(RunnerBase):
             metrics_result = []
             for key in metrics:
                 if isinstance(metrics[key], Metric):
-                    _str = metrics[key].cal_global_metrics(
+                    _str = metrics[key].calc_global_metrics(
                         None, context["model"][model_dict["name"]]["scope"])
                     metrics_result.append(_str)
 
--
GitLab
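
Not part of the patch itself: below the signature, a minimal runnable sketch of
the contract these renames leave behind, since the hunks only show fragments.
The public calc_global_metrics() wrapper drives the private
_get_global_metric_state() and _calculate() hooks that concrete metrics (AUC,
PrecisionRecall, PosNegRatio, RecallK) override, and it is the only method the
runners touch. DummyMetric, its "total_var" state variable, and the stubbed
tensor read are hypothetical stand-ins; only the method names and call shapes
come from core/metric.py as patched above.

# Illustrative sketch only; DummyMetric and "total_var" are made up.
# The real Metric lives in core/metric.py and reads Paddle tensors.
import numpy as np


class Metric(object):
    """Simplified stand-in for core.metric.Metric after this patch."""

    def __init__(self):
        # metric key -> (variable name, dtype), as in _global_metric_state_vars
        self._global_metric_state_vars = dict()

    def _get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
        # The real method fetches `metric_name` from `scope`; with a fleet it
        # also barriers and all-reduces across workers. Stubbed locally here.
        return np.array([1.0, 2.0, 3.0])

    def calc_global_metrics(self, fleet, scope=None):
        # Public entry point, the one the runners call.
        global_metrics = dict()
        for key in self._global_metric_state_vars:
            varname, dtype = self._global_metric_state_vars[key]
            global_metrics[key] = self._get_global_metric_state(fleet, scope,
                                                                varname)
        return self._calculate(global_metrics)

    def _calculate(self, global_metrics):
        # Private hook each concrete metric overrides.
        raise NotImplementedError


class DummyMetric(Metric):
    def __init__(self):
        super(DummyMetric, self).__init__()
        self._global_metric_state_vars = {"total": ("total_var", "float32")}

    def _calculate(self, global_metrics):
        return {"sum": float(np.sum(global_metrics["total"]))}


print(DummyMetric().calc_global_metrics(fleet=None))  # {'sum': 6.0}

Keeping the hooks private and routing every caller through calc_global_metrics()
gives each metric a single public surface, which is why each runner.py hunk
above changes exactly one line: cal_global_metrics -> calc_global_metrics.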