提交 edf1129e 编写于 作者: Z zhiboniu

match new eval function

上级 699c10aa
......@@ -5,7 +5,7 @@ Global:
output_dir: "./output/"
device: "gpu"
save_interval: 5
eval_during_train: False
eval_during_train: True
eval_interval: 1
epochs: 30
print_batch_step: 20
......
......@@ -18,7 +18,7 @@ import time
import platform
import paddle
from ppcls.utils.misc import AverageMeter, AttrMeter
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
......@@ -34,10 +34,6 @@ def classification_eval(engine, epoch_id=0):
}
print_batch_step = engine.config["Global"]["print_batch_step"]
if engine.eval_metric_func is not None and "ATTRMetric" in engine.config[
"Metric"]["Eval"][0]:
output_info["attr"] = AttrMeter(threshold=0.5)
metric_key = None
tic = time.time()
accum_samples = 0
......@@ -162,7 +158,7 @@ def classification_eval(engine, epoch_id=0):
if "ATTRMetric" in engine.config["Metric"]["Eval"][0]:
metric_msg = ", ".join([
"evalres: ma: {:.5f} label_f1: {:.5f} label_pos_recall: {:.5f} label_neg_recall: {:.5f} instance_f1: {:.5f} instance_acc: {:.5f} instance_prec: {:.5f} instance_recall: {:.5f}".
format(*output_info["attr"].res())
format(*engine.eval_metric_func.attr_res())
])
logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))
......@@ -170,7 +166,7 @@ def classification_eval(engine, epoch_id=0):
if engine.eval_metric_func is None:
return -1
# return 1st metric in the dict
return output_info["attr"].res()[0]
return engine.eval_metric_func.attr_res()[0]
else:
metric_msg = ", ".join([
"{}: {:.5f}".format(key, output_info[key].avg)
......
......@@ -56,6 +56,9 @@ class CombinedMetrics(AvgMetrics):
def avg(self):
return self.metric_func_list[0].avg
def attr_res(self):
return self.metric_func_list[0].attrmeter.res()
def reset(self):
for metric in self.metric_func_list:
if hasattr(metric, "reset"):
......
......@@ -25,7 +25,7 @@ from sklearn.preprocessing import binarize
from easydict import EasyDict
from ppcls.metric.avg_metrics import AvgMetrics
from ppcls.utils.misc import AverageMeter
from ppcls.utils.misc import AverageMeter, AttrMeter
class TopkAcc(AvgMetrics):
......@@ -438,7 +438,11 @@ class ATTRMetric(nn.Layer):
super().__init__()
self.threshold = threshold
def reset(self):
self.attrmeter = AttrMeter(threshold=0.5)
def forward(self, output, target):
metric_dict = get_attr_metrics(target[:, 0, :].numpy(),
output.numpy(), self.threshold)
self.attrmeter.update(metric_dict)
return metric_dict
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册