提交 2949a20f 编写于 作者: M malin10

rename global_metric_state_vars

上级 c1a4a6b8
...@@ -32,15 +32,16 @@ class Metric(object): ...@@ -32,15 +32,16 @@ class Metric(object):
scope = fluid.global_scope() scope = fluid.global_scope()
place = fluid.CPUPlace() place = fluid.CPUPlace()
for key in self._global_communicate_var: for key in self._global_metric_state_vars:
varname, dtype = self._global_communicate_var[key] varname, dtype = self._global_metric_state_vars[key]
if scope.find_var(varname) is None: var = scope.find_var(varname)
if not var:
continue continue
var = scope.var(varname).get_tensor() var = var.get_tensor()
data_array = np.zeros(var._get_dims()).astype(dtype) data_array = np.zeros(var._get_dims()).astype(dtype)
var.set(data_array, place) var.set(data_array, place)
def get_global_metric(self, fleet, scope, metric_name, mode="sum"): def get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
""" """ """ """
input = np.array(scope.find_var(metric_name).get_tensor()) input = np.array(scope.find_var(metric_name).get_tensor())
if fleet is None: if fleet is None:
...@@ -59,9 +60,10 @@ class Metric(object): ...@@ -59,9 +60,10 @@ class Metric(object):
scope = fluid.global_scope() scope = fluid.global_scope()
global_metrics = dict() global_metrics = dict()
for key in self._global_communicate_var: for key in self._global_metric_state_vars:
varname, dtype = self._global_communicate_var[key] varname, dtype = self._global_metric_state_vars[key]
global_metrics[key] = self.get_global_metric(fleet, scope, varname) global_metrics[key] = self.get_global_metric_state(fleet, scope,
varname)
return self.calculate(global_metrics) return self.calculate(global_metrics)
......
...@@ -59,15 +59,16 @@ class AUC(Metric): ...@@ -59,15 +59,16 @@ class AUC(Metric):
sqrerr, abserr, prob, q, pos, total = \ sqrerr, abserr, prob, q, pos, total = \
fluid.contrib.layers.ctr_metric_bundle(prob, label_cast) fluid.contrib.layers.ctr_metric_bundle(prob, label_cast)
self._global_communicate_var = dict() self._global_metric_state_vars = dict()
self._global_communicate_var['stat_pos'] = (stat_pos.name, "float32") self._global_metric_state_vars['stat_pos'] = (stat_pos.name, "float32")
self._global_communicate_var['stat_neg'] = (stat_neg.name, "float32") self._global_metric_state_vars['stat_neg'] = (stat_neg.name, "float32")
self._global_communicate_var['total_ins_num'] = (total.name, "float32") self._global_metric_state_vars['total_ins_num'] = (total.name,
self._global_communicate_var['pos_ins_num'] = (pos.name, "float32") "float32")
self._global_communicate_var['q'] = (q.name, "float32") self._global_metric_state_vars['pos_ins_num'] = (pos.name, "float32")
self._global_communicate_var['prob'] = (prob.name, "float32") self._global_metric_state_vars['q'] = (q.name, "float32")
self._global_communicate_var['abserr'] = (abserr.name, "float32") self._global_metric_state_vars['prob'] = (prob.name, "float32")
self._global_communicate_var['sqrerr'] = (sqrerr.name, "float32") self._global_metric_state_vars['abserr'] = (abserr.name, "float32")
self._global_metric_state_vars['sqrerr'] = (sqrerr.name, "float32")
self.metrics = dict() self.metrics = dict()
self.metrics["AUC"] = auc_out self.metrics["AUC"] = auc_out
...@@ -149,7 +150,7 @@ class AUC(Metric): ...@@ -149,7 +150,7 @@ class AUC(Metric):
def calculate(self, global_metrics): def calculate(self, global_metrics):
result = dict() result = dict()
for key in self._global_communicate_var: for key in self._global_metric_state_vars:
if key not in global_metrics: if key not in global_metrics:
raise ValueError("%s not existed" % key) raise ValueError("%s not existed" % key)
result[key] = global_metrics[key][0] result[key] = global_metrics[key][0]
......
...@@ -99,8 +99,8 @@ class PrecisionRecall(Metric): ...@@ -99,8 +99,8 @@ class PrecisionRecall(Metric):
batch_states.stop_gradient = True batch_states.stop_gradient = True
states_info.stop_gradient = True states_info.stop_gradient = True
self._global_communicate_var = dict() self._global_metric_state_vars = dict()
self._global_communicate_var['states_info'] = (states_info.name, self._global_metric_state_vars['states_info'] = (states_info.name,
"float32") "float32")
self.metrics = dict() self.metrics = dict()
...@@ -110,7 +110,7 @@ class PrecisionRecall(Metric): ...@@ -110,7 +110,7 @@ class PrecisionRecall(Metric):
# self.metrics["batch_metrics"] = batch_metrics # self.metrics["batch_metrics"] = batch_metrics
def calculate(self, global_metrics): def calculate(self, global_metrics):
for key in self._global_communicate_var: for key in self._global_metric_state_vars:
if key not in global_metrics: if key not in global_metrics:
raise ValueError("%s not existed" % key) raise ValueError("%s not existed" % key)
......
...@@ -73,10 +73,10 @@ class PosNegRatio(Metric): ...@@ -73,10 +73,10 @@ class PosNegRatio(Metric):
outputs={"Out": [global_wrong_cnt]}) outputs={"Out": [global_wrong_cnt]})
self.pn = (global_right_cnt + 1.0) / (global_wrong_cnt + 1.0) self.pn = (global_right_cnt + 1.0) / (global_wrong_cnt + 1.0)
self._global_communicate_var = dict() self._global_metric_state_vars = dict()
self._global_communicate_var['right_cnt'] = (global_right_cnt.name, self._global_metric_state_vars['right_cnt'] = (global_right_cnt.name,
"float32") "float32")
self._global_communicate_var['wrong_cnt'] = (global_wrong_cnt.name, self._global_metric_state_vars['wrong_cnt'] = (global_wrong_cnt.name,
"float32") "float32")
self.metrics = dict() self.metrics = dict()
......
...@@ -75,10 +75,10 @@ class RecallK(Metric): ...@@ -75,10 +75,10 @@ class RecallK(Metric):
self.acc = global_pos_cnt / global_ins_cnt self.acc = global_pos_cnt / global_ins_cnt
self._global_communicate_var = dict() self._global_metric_state_vars = dict()
self._global_communicate_var['ins_cnt'] = (global_ins_cnt.name, self._global_metric_state_vars['ins_cnt'] = (global_ins_cnt.name,
"float32") "float32")
self._global_communicate_var['pos_cnt'] = (global_pos_cnt.name, self._global_metric_state_vars['pos_cnt'] = (global_pos_cnt.name,
"float32") "float32")
metric_name = "Acc(Recall@%d)" % self.k metric_name = "Acc(Recall@%d)" % self.k
...@@ -89,7 +89,7 @@ class RecallK(Metric): ...@@ -89,7 +89,7 @@ class RecallK(Metric):
# self.metrics["batch_metrics"] = batch_metrics # self.metrics["batch_metrics"] = batch_metrics
def calculate(self, global_metrics): def calculate(self, global_metrics):
for key in self._global_communicate_var: for key in self._global_metric_state_vars:
if key not in global_metrics: if key not in global_metrics:
raise ValueError("%s not existed" % key) raise ValueError("%s not existed" % key)
ins_cnt = global_metrics['ins_cnt'][0] ins_cnt = global_metrics['ins_cnt'][0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册