提交 cf40ed6f 编写于 作者: S shippingwang

fixed

上级 62772c11
......@@ -81,10 +81,8 @@ def error(fmt, *args):
_logger.error(coloring(fmt, "FAIL"), *args)
def scaler(name, value, step, path):
from visualdl import LogWriter
vdl_writer = LogWriter(path)
vdl_writer.add_scalar(name, value, step)
def scaler(name, value, step, writer):
writer.add_scalar(name, value, step)
def advertise():
......
......@@ -387,7 +387,7 @@ def compile(config, program, loss_name=None):
total_step = 0
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
"""
Feed data to the model and fetch the measures and loss
......@@ -415,9 +415,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.value]) + 's'
if vdl_dir:
if vdl_writer:
global total_step
logger.scaler('loss', metrics[0][0], total_step, vdl_dir)
logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
total_step += 1
if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
import argparse
import os
from visualdl import LogWriter
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.incubate.fleet.collective import fleet
......@@ -96,10 +97,12 @@ def main(args):
compiled_valid_prog = program.compile(config, valid_prog)
compiled_train_prog = fleet.main_program
vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None
for epoch_id in range(config.epochs):
# 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train', args.vdl_dir)
epoch_id, 'train', vdl_writer)
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
# 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册