提交 edf1129e 编写于 作者: Z zhiboniu

match new eval function

上级 699c10aa
...@@ -5,7 +5,7 @@ Global: ...@@ -5,7 +5,7 @@ Global:
output_dir: "./output/" output_dir: "./output/"
device: "gpu" device: "gpu"
save_interval: 5 save_interval: 5
eval_during_train: False eval_during_train: True
eval_interval: 1 eval_interval: 1
epochs: 30 epochs: 30
print_batch_step: 20 print_batch_step: 20
......
...@@ -18,7 +18,7 @@ import time ...@@ -18,7 +18,7 @@ import time
import platform import platform
import paddle import paddle
from ppcls.utils.misc import AverageMeter, AttrMeter from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger from ppcls.utils import logger
...@@ -34,10 +34,6 @@ def classification_eval(engine, epoch_id=0): ...@@ -34,10 +34,6 @@ def classification_eval(engine, epoch_id=0):
} }
print_batch_step = engine.config["Global"]["print_batch_step"] print_batch_step = engine.config["Global"]["print_batch_step"]
if engine.eval_metric_func is not None and "ATTRMetric" in engine.config[
"Metric"]["Eval"][0]:
output_info["attr"] = AttrMeter(threshold=0.5)
metric_key = None metric_key = None
tic = time.time() tic = time.time()
accum_samples = 0 accum_samples = 0
...@@ -162,7 +158,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -162,7 +158,7 @@ def classification_eval(engine, epoch_id=0):
if "ATTRMetric" in engine.config["Metric"]["Eval"][0]: if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
metric_msg = ", ".join([ metric_msg = ", ".join([
"evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}". "evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
format(*output_info["attr"].res()) format(*engine.eval_metric_func.attr_res())
]) ])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg)) logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
...@@ -170,7 +166,7 @@ def classification_eval(engine, epoch_id=0): ...@@ -170,7 +166,7 @@ def classification_eval(engine, epoch_id=0):
if engine.eval_metric_func is None: if engine.eval_metric_func is None:
return -1 return -1
# return 1st metric in the dict # return 1st metric in the dict
return output_info["attr"].res()[0] return engine.eval_metric_func.attr_res()[0]
else: else:
metric_msg = ", ".join([ metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg) "{}: {:.5f}".format(key, output_info[key].avg)
......
...@@ -56,6 +56,9 @@ class CombinedMetrics(AvgMetrics): ...@@ -56,6 +56,9 @@ class CombinedMetrics(AvgMetrics):
def avg(self): def avg(self):
return self.metric_func_list[0].avg return self.metric_func_list[0].avg
def attr_res(self):
return self.metric_func_list[0].attrmeter.res()
def reset(self): def reset(self):
for metric in self.metric_func_list: for metric in self.metric_func_list:
if hasattr(metric, "reset"): if hasattr(metric, "reset"):
......
...@@ -25,7 +25,7 @@ from sklearn.preprocessing import binarize ...@@ -25,7 +25,7 @@ from sklearn.preprocessing import binarize
from easydict import EasyDict from easydict import EasyDict
from ppcls.metric.avg_metrics import AvgMetrics from ppcls.metric.avg_metrics import AvgMetrics
from ppcls.utils.misc import AverageMeter from ppcls.utils.misc import AverageMeter, AttrMeter
class TopkAcc(AvgMetrics): class TopkAcc(AvgMetrics):
...@@ -438,7 +438,11 @@ class ATTRMetric(nn.Layer): ...@@ -438,7 +438,11 @@ class ATTRMetric(nn.Layer):
super().__init__() super().__init__()
self.threshold = threshold self.threshold = threshold
def reset(self):
self.attrmeter = AttrMeter(threshold=0.5)
def forward(self, output, target): def forward(self, output, target):
metric_dict = get_attr_metrics(target[:, 0, :].numpy(), metric_dict = get_attr_metrics(target[:, 0, :].numpy(),
output.numpy(), self.threshold) output.numpy(), self.threshold)
self.attrmeter.update(metric_dict)
return metric_dict return metric_dict
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册