未验证 提交 0189236b 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add support for vdl (#978)

* add support for vdl

* fix vdl eval
上级 f8a8c51e
...@@ -26,6 +26,7 @@ import argparse ...@@ -26,6 +26,7 @@ import argparse
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
import paddle.distributed as dist import paddle.distributed as dist
from visualdl import LogWriter
from ppcls.utils.check import check_gpu from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter from ppcls.utils.misc import AverageMeter
...@@ -83,8 +84,7 @@ class Trainer(object): ...@@ -83,8 +84,7 @@ class Trainer(object):
self.model = paddle.DataParallel(self.model) self.model = paddle.DataParallel(self.model)
self.vdl_writer = None self.vdl_writer = None
if self.config['Global']['use_visualdl']: if self.config['Global']['use_visualdl'] and mode == "train":
from visualdl import LogWriter
vdl_writer_path = os.path.join(self.output_dir, "vdl") vdl_writer_path = os.path.join(self.output_dir, "vdl")
if not os.path.exists(vdl_writer_path): if not os.path.exists(vdl_writer_path):
os.makedirs(vdl_writer_path) os.makedirs(vdl_writer_path)
...@@ -219,6 +219,18 @@ class Trainer(object): ...@@ -219,6 +219,18 @@ class Trainer(object):
"epochs"], iter_id, "epochs"], iter_id,
len(self.train_dataloader), lr_msg, metric_msg, len(self.train_dataloader), lr_msg, metric_msg,
time_msg, ips_msg, eta_msg)) time_msg, ips_msg, eta_msg))
logger.scaler(
name="lr",
value=lr_sch.get_lr(),
step=global_step,
writer=self.vdl_writer)
for key in output_info:
logger.scaler(
name="train_{}".format(key),
value=output_info[key].avg,
step=global_step,
writer=self.vdl_writer)
tic = time.time() tic = time.time()
metric_msg = ", ".join([ metric_msg = ", ".join([
...@@ -246,6 +258,12 @@ class Trainer(object): ...@@ -246,6 +258,12 @@ class Trainer(object):
prefix="best_model") prefix="best_model")
logger.info("[Eval][Epoch {}][best metric: {}]".format( logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, best_metric["metric"])) epoch_id, best_metric["metric"]))
logger.scaler(
name="eval_acc",
value=acc,
step=epoch_id,
writer=self.vdl_writer)
self.model.train() self.model.train()
# save model # save model
...@@ -266,6 +284,9 @@ class Trainer(object): ...@@ -266,6 +284,9 @@ class Trainer(object):
model_name=self.config["Arch"]["name"], model_name=self.config["Arch"]["name"],
prefix="latest") prefix="latest")
if self.vdl_writer is not None:
self.vdl_writer.close()
def build_avg_metrics(self, info_dict): def build_avg_metrics(self, info_dict):
return {key: AverageMeter(key, '7.5f') for key in info_dict} return {key: AverageMeter(key, '7.5f') for key in info_dict}
......
...@@ -102,6 +102,8 @@ def scaler(name, value, step, writer): ...@@ -102,6 +102,8 @@ def scaler(name, value, step, writer):
visualdl --logdir ./scalar --host 0.0.0.0 --port 8830 visualdl --logdir ./scalar --host 0.0.0.0 --port 8830
to preview loss curve in real time. to preview loss curve in real time.
""" """
if writer is None:
return
writer.add_scalar(tag=name, step=step, value=value) writer.add_scalar(tag=name, step=step, value=value)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册