提交 2949a20f 编写于 作者: M malin10

rename global_metric_state_vars

上级 c1a4a6b8
......@@ -32,15 +32,16 @@ class Metric(object):
scope = fluid.global_scope()
place = fluid.CPUPlace()
for key in self._global_communicate_var:
varname, dtype = self._global_communicate_var[key]
if scope.find_var(varname) is None:
for key in self._global_metric_state_vars:
varname, dtype = self._global_metric_state_vars[key]
var = scope.find_var(varname)
if not var:
continue
var = scope.var(varname).get_tensor()
var = var.get_tensor()
data_array = np.zeros(var._get_dims()).astype(dtype)
var.set(data_array, place)
def get_global_metric(self, fleet, scope, metric_name, mode="sum"):
def get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
""" """
input = np.array(scope.find_var(metric_name).get_tensor())
if fleet is None:
......@@ -59,9 +60,10 @@ class Metric(object):
scope = fluid.global_scope()
global_metrics = dict()
for key in self._global_communicate_var:
varname, dtype = self._global_communicate_var[key]
global_metrics[key] = self.get_global_metric(fleet, scope, varname)
for key in self._global_metric_state_vars:
varname, dtype = self._global_metric_state_vars[key]
global_metrics[key] = self.get_global_metric_state(fleet, scope,
varname)
return self.calculate(global_metrics)
......
......@@ -59,15 +59,16 @@ class AUC(Metric):
sqrerr, abserr, prob, q, pos, total = \
fluid.contrib.layers.ctr_metric_bundle(prob, label_cast)
self._global_communicate_var = dict()
self._global_communicate_var['stat_pos'] = (stat_pos.name, "float32")
self._global_communicate_var['stat_neg'] = (stat_neg.name, "float32")
self._global_communicate_var['total_ins_num'] = (total.name, "float32")
self._global_communicate_var['pos_ins_num'] = (pos.name, "float32")
self._global_communicate_var['q'] = (q.name, "float32")
self._global_communicate_var['prob'] = (prob.name, "float32")
self._global_communicate_var['abserr'] = (abserr.name, "float32")
self._global_communicate_var['sqrerr'] = (sqrerr.name, "float32")
self._global_metric_state_vars = dict()
self._global_metric_state_vars['stat_pos'] = (stat_pos.name, "float32")
self._global_metric_state_vars['stat_neg'] = (stat_neg.name, "float32")
self._global_metric_state_vars['total_ins_num'] = (total.name,
"float32")
self._global_metric_state_vars['pos_ins_num'] = (pos.name, "float32")
self._global_metric_state_vars['q'] = (q.name, "float32")
self._global_metric_state_vars['prob'] = (prob.name, "float32")
self._global_metric_state_vars['abserr'] = (abserr.name, "float32")
self._global_metric_state_vars['sqrerr'] = (sqrerr.name, "float32")
self.metrics = dict()
self.metrics["AUC"] = auc_out
......@@ -149,7 +150,7 @@ class AUC(Metric):
def calculate(self, global_metrics):
result = dict()
for key in self._global_communicate_var:
for key in self._global_metric_state_vars:
if key not in global_metrics:
raise ValueError("%s not existed" % key)
result[key] = global_metrics[key][0]
......
......@@ -99,8 +99,8 @@ class PrecisionRecall(Metric):
batch_states.stop_gradient = True
states_info.stop_gradient = True
self._global_communicate_var = dict()
self._global_communicate_var['states_info'] = (states_info.name,
self._global_metric_state_vars = dict()
self._global_metric_state_vars['states_info'] = (states_info.name,
"float32")
self.metrics = dict()
......@@ -110,7 +110,7 @@ class PrecisionRecall(Metric):
# self.metrics["batch_metrics"] = batch_metrics
def calculate(self, global_metrics):
for key in self._global_communicate_var:
for key in self._global_metric_state_vars:
if key not in global_metrics:
raise ValueError("%s not existed" % key)
......
......@@ -73,10 +73,10 @@ class PosNegRatio(Metric):
outputs={"Out": [global_wrong_cnt]})
self.pn = (global_right_cnt + 1.0) / (global_wrong_cnt + 1.0)
self._global_communicate_var = dict()
self._global_communicate_var['right_cnt'] = (global_right_cnt.name,
self._global_metric_state_vars = dict()
self._global_metric_state_vars['right_cnt'] = (global_right_cnt.name,
"float32")
self._global_communicate_var['wrong_cnt'] = (global_wrong_cnt.name,
self._global_metric_state_vars['wrong_cnt'] = (global_wrong_cnt.name,
"float32")
self.metrics = dict()
......
......@@ -75,10 +75,10 @@ class RecallK(Metric):
self.acc = global_pos_cnt / global_ins_cnt
self._global_communicate_var = dict()
self._global_communicate_var['ins_cnt'] = (global_ins_cnt.name,
self._global_metric_state_vars = dict()
self._global_metric_state_vars['ins_cnt'] = (global_ins_cnt.name,
"float32")
self._global_communicate_var['pos_cnt'] = (global_pos_cnt.name,
self._global_metric_state_vars['pos_cnt'] = (global_pos_cnt.name,
"float32")
metric_name = "Acc(Recall@%d)" % self.k
......@@ -89,7 +89,7 @@ class RecallK(Metric):
# self.metrics["batch_metrics"] = batch_metrics
def calculate(self, global_metrics):
for key in self._global_communicate_var:
for key in self._global_metric_state_vars:
if key not in global_metrics:
raise ValueError("%s not existed" % key)
ins_cnt = global_metrics['ins_cnt'][0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册