diff --git a/core/metric.py b/core/metric.py index 6ba39a1043e157b444575a689071611123bbb9e3..d621a06ead5fa85b0bb9d8f3e13a8f15aa6dafa5 100755 --- a/core/metric.py +++ b/core/metric.py @@ -32,15 +32,16 @@ class Metric(object): scope = fluid.global_scope() place = fluid.CPUPlace() - for key in self._global_communicate_var: - varname, dtype = self._global_communicate_var[key] - if scope.find_var(varname) is None: + for key in self._global_metric_state_vars: + varname, dtype = self._global_metric_state_vars[key] + var = scope.find_var(varname) + if not var: continue - var = scope.var(varname).get_tensor() + var = var.get_tensor() data_array = np.zeros(var._get_dims()).astype(dtype) var.set(data_array, place) - def get_global_metric(self, fleet, scope, metric_name, mode="sum"): + def get_global_metric_state(self, fleet, scope, metric_name, mode="sum"): """ """ input = np.array(scope.find_var(metric_name).get_tensor()) if fleet is None: @@ -59,9 +60,10 @@ class Metric(object): scope = fluid.global_scope() global_metrics = dict() - for key in self._global_communicate_var: - varname, dtype = self._global_communicate_var[key] - global_metrics[key] = self.get_global_metric(fleet, scope, varname) + for key in self._global_metric_state_vars: + varname, dtype = self._global_metric_state_vars[key] + global_metrics[key] = self.get_global_metric_state(fleet, scope, + varname) return self.calculate(global_metrics) diff --git a/core/metrics/binary_class/auc.py b/core/metrics/binary_class/auc.py index a2e54a2c40268cc7ec0678938c28adfd4a2f5655..b847314636df53be8bdabcdd58961470b61148f5 100755 --- a/core/metrics/binary_class/auc.py +++ b/core/metrics/binary_class/auc.py @@ -59,15 +59,16 @@ class AUC(Metric): sqrerr, abserr, prob, q, pos, total = \ fluid.contrib.layers.ctr_metric_bundle(prob, label_cast) - self._global_communicate_var = dict() - self._global_communicate_var['stat_pos'] = (stat_pos.name, "float32") - self._global_communicate_var['stat_neg'] = (stat_neg.name, "float32") - self._global_communicate_var['total_ins_num'] = (total.name, "float32") - self._global_communicate_var['pos_ins_num'] = (pos.name, "float32") - self._global_communicate_var['q'] = (q.name, "float32") - self._global_communicate_var['prob'] = (prob.name, "float32") - self._global_communicate_var['abserr'] = (abserr.name, "float32") - self._global_communicate_var['sqrerr'] = (sqrerr.name, "float32") + self._global_metric_state_vars = dict() + self._global_metric_state_vars['stat_pos'] = (stat_pos.name, "float32") + self._global_metric_state_vars['stat_neg'] = (stat_neg.name, "float32") + self._global_metric_state_vars['total_ins_num'] = (total.name, + "float32") + self._global_metric_state_vars['pos_ins_num'] = (pos.name, "float32") + self._global_metric_state_vars['q'] = (q.name, "float32") + self._global_metric_state_vars['prob'] = (prob.name, "float32") + self._global_metric_state_vars['abserr'] = (abserr.name, "float32") + self._global_metric_state_vars['sqrerr'] = (sqrerr.name, "float32") self.metrics = dict() self.metrics["AUC"] = auc_out @@ -149,7 +150,7 @@ class AUC(Metric): def calculate(self, global_metrics): result = dict() - for key in self._global_communicate_var: + for key in self._global_metric_state_vars: if key not in global_metrics: raise ValueError("%s not existed" % key) result[key] = global_metrics[key][0] diff --git a/core/metrics/binary_class/precision_recall.py b/core/metrics/binary_class/precision_recall.py index d4a7fe1fb49421b91a50339d0e43c817074258a3..0eb80765232f7c318c49b926f61c74434c163d5d 100755 --- a/core/metrics/binary_class/precision_recall.py +++ b/core/metrics/binary_class/precision_recall.py @@ -99,9 +99,9 @@ class PrecisionRecall(Metric): batch_states.stop_gradient = True states_info.stop_gradient = True - self._global_communicate_var = dict() - self._global_communicate_var['states_info'] = (states_info.name, - "float32") + self._global_metric_state_vars = dict() + self._global_metric_state_vars['states_info'] = (states_info.name, + "float32") self.metrics = dict() self.metrics["precision_recall_f1"] = accum_metrics @@ -110,7 +110,7 @@ class PrecisionRecall(Metric): # self.metrics["batch_metrics"] = batch_metrics def calculate(self, global_metrics): - for key in self._global_communicate_var: + for key in self._global_metric_state_vars: if key not in global_metrics: raise ValueError("%s not existed" % key) diff --git a/core/metrics/pairwise_pn.py b/core/metrics/pairwise_pn.py index 09a9ffbd41e0f03a5be8e26fff3e6960850e3ecc..673ce79b9e8f5b7ec73fc440cdc4d959747cae26 100755 --- a/core/metrics/pairwise_pn.py +++ b/core/metrics/pairwise_pn.py @@ -73,11 +73,11 @@ class PosNegRatio(Metric): outputs={"Out": [global_wrong_cnt]}) self.pn = (global_right_cnt + 1.0) / (global_wrong_cnt + 1.0) - self._global_communicate_var = dict() - self._global_communicate_var['right_cnt'] = (global_right_cnt.name, - "float32") - self._global_communicate_var['wrong_cnt'] = (global_wrong_cnt.name, - "float32") + self._global_metric_state_vars = dict() + self._global_metric_state_vars['right_cnt'] = (global_right_cnt.name, + "float32") + self._global_metric_state_vars['wrong_cnt'] = (global_wrong_cnt.name, + "float32") self.metrics = dict() self.metrics['WrongCnt'] = global_wrong_cnt diff --git a/core/metrics/recall_k.py b/core/metrics/recall_k.py index 89216a5211b0633c7fd73abeed4569d1be3c24bd..f570ef222f2cdb68cd5fc283539755ce97123905 100755 --- a/core/metrics/recall_k.py +++ b/core/metrics/recall_k.py @@ -75,11 +75,11 @@ class RecallK(Metric): self.acc = global_pos_cnt / global_ins_cnt - self._global_communicate_var = dict() - self._global_communicate_var['ins_cnt'] = (global_ins_cnt.name, - "float32") - self._global_communicate_var['pos_cnt'] = (global_pos_cnt.name, - "float32") + self._global_metric_state_vars = dict() + self._global_metric_state_vars['ins_cnt'] = (global_ins_cnt.name, + "float32") + self._global_metric_state_vars['pos_cnt'] = (global_pos_cnt.name, + "float32") metric_name = "Acc(Recall@%d)" % self.k self.metrics = dict() @@ -89,7 +89,7 @@ class RecallK(Metric): # self.metrics["batch_metrics"] = batch_metrics def calculate(self, global_metrics): - for key in self._global_communicate_var: + for key in self._global_metric_state_vars: if key not in global_metrics: raise ValueError("%s not existed" % key) ins_cnt = global_metrics['ins_cnt'][0]