diff --git a/tools/train.py b/tools/train.py index 4384b1528f4418035ca03218aeb7fa62cd94ac4b..a5f765f066bfefcf419cb78518a4b58d870c326c 100644 --- a/tools/train.py +++ b/tools/train.py @@ -19,7 +19,6 @@ from __future__ import print_function import argparse import os -from visualdl import LogWriter import paddle.fluid as fluid from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.collective import fleet @@ -101,7 +100,12 @@ def main(args): compiled_valid_prog = program.compile(config, valid_prog) compiled_train_prog = fleet.main_program - vdl_writer = LogWriter(args.vdl_dir) if args.vdl_dir else None + + if args.vdl_dir: + from visualdl import LogWriter + vdl_writer = LogWriter(args.vdl_dir) + else: + vdl_writer = None for epoch_id in range(config.epochs): # 1. train with train dataset