提交 cf40ed6f 编写于 作者: S shippingwang

fixed

上级 62772c11
...@@ -81,10 +81,8 @@ def error(fmt, *args): ...@@ -81,10 +81,8 @@ def error(fmt, *args):
_logger.error(coloring(fmt, "FAIL"), *args) _logger.error(coloring(fmt, "FAIL"), *args)
def scaler(name, value, step, path): def scaler(name, value, step, writer):
from visualdl import LogWriter writer.add_scalar(name, value, step)
vdl_writer = LogWriter(path)
vdl_writer.add_scalar(name, value, step)
def advertise(): def advertise():
......
...@@ -387,7 +387,7 @@ def compile(config, program, loss_name=None): ...@@ -387,7 +387,7 @@ def compile(config, program, loss_name=None):
total_step = 0 total_step = 0
def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None): def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_writer=None):
""" """
Feed data to the model and fetch the measures and loss Feed data to the model and fetch the measures and loss
...@@ -415,9 +415,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None): ...@@ -415,9 +415,9 @@ def run(dataloader, exe, program, fetchs, epoch=0, mode='train', vdl_dir=None):
metric_list[i].update(m[0], len(batch[0])) metric_list[i].update(m[0], len(batch[0]))
fetchs_str = ''.join([str(m.value) + ' ' fetchs_str = ''.join([str(m.value) + ' '
for m in metric_list] + [batch_time.value]) + 's' for m in metric_list] + [batch_time.value]) + 's'
if vdl_dir: if vdl_writer:
global total_step global total_step
logger.scaler('loss', metrics[0][0], total_step, vdl_dir) logger.scaler('loss', metrics[0][0], total_step, vdl_writer)
total_step += 1 total_step += 1
if mode == 'eval': if mode == 'eval':
logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str)) logger.info("{:s} step:{:<4d} {:s}s".format(mode, idx, fetchs_str))
......
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import argparse import argparse
import os import os
from visualdl import LogWriter
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.collective import fleet
...@@ -96,10 +97,12 @@ def main(args): ...@@ -96,10 +97,12 @@ def main(args):
compiled_valid_prog = program.compile(config, valid_prog) compiled_valid_prog = program.compile(config, valid_prog)
compiled_train_prog = fleet.main_program compiled_train_prog = fleet.main_program
vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None
for epoch_id in range(config.epochs): for epoch_id in range(config.epochs):
# 1. train with train dataset # 1. train with train dataset
program.run(train_dataloader, exe, compiled_train_prog, train_fetchs, program.run(train_dataloader, exe, compiled_train_prog, train_fetchs,
epoch_id, 'train', args.vdl_dir) epoch_id, 'train', vdl_writer)
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0: if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
# 2. validate with validate dataset # 2. validate with validate dataset
if config.validate and epoch_id % config.valid_interval == 0: if config.validate and epoch_id % config.valid_interval == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册