未验证 提交 8908f2e3 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #1683 from guoshengCS/fix-transformer-logging

Fix logging and checkpoint loading in Transformer
...@@ -408,10 +408,19 @@ def test_context(exe, train_exe, dev_count): ...@@ -408,10 +408,19 @@ def test_context(exe, train_exe, dev_count):
test_data = prepare_data_generator( test_data = prepare_data_generator(
args, is_test=True, count=dev_count, pyreader=pyreader) args, is_test=True, count=dev_count, pyreader=pyreader)
exe.run(startup_prog) exe.run(startup_prog) # to init pyreader for testing
if TrainTaskConfig.ckpt_path:
fluid.io.load_persistables(
exe, TrainTaskConfig.ckpt_path, main_program=test_prog)
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.use_experimental_executor = True
build_strategy = fluid.BuildStrategy()
test_exe = fluid.ParallelExecutor( test_exe = fluid.ParallelExecutor(
use_cuda=TrainTaskConfig.use_gpu, use_cuda=TrainTaskConfig.use_gpu,
main_program=test_prog, main_program=test_prog,
build_strategy=build_strategy,
exec_strategy=exec_strategy,
share_vars_from=train_exe) share_vars_from=train_exe)
def test(exe=test_exe, pyreader=pyreader): def test(exe=test_exe, pyreader=pyreader):
...@@ -457,7 +466,11 @@ def train_loop(exe, ...@@ -457,7 +466,11 @@ def train_loop(exe,
nccl2_trainer_id=0): nccl2_trainer_id=0):
# Initialize the parameters. # Initialize the parameters.
if TrainTaskConfig.ckpt_path: if TrainTaskConfig.ckpt_path:
fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path) exe.run(startup_prog) # to init pyreader for training
logging.info("load checkpoint from {}".format(
TrainTaskConfig.ckpt_path))
fluid.io.load_persistables(
exe, TrainTaskConfig.ckpt_path, main_program=train_prog)
else: else:
logging.info("init fluid.framework.default_startup_program") logging.info("init fluid.framework.default_startup_program")
exe.run(startup_prog) exe.run(startup_prog)
...@@ -741,6 +754,7 @@ if __name__ == "__main__": ...@@ -741,6 +754,7 @@ if __name__ == "__main__":
LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s" LOG_FORMAT = "[%(asctime)s %(levelname)s %(filename)s:%(lineno)d] %(message)s"
logging.basicConfig( logging.basicConfig(
stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT) stream=sys.stdout, level=logging.DEBUG, format=LOG_FORMAT)
logging.getLogger().setLevel(logging.INFO)
args = parse_args() args = parse_args()
train(args) train(args)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册