提交 1dee0622 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add shared prog

上级 0a0d5bc0
......@@ -373,7 +373,7 @@ def build(config, main_prog, startup_prog, is_train=True, is_distributed=True):
return dataloader, fetchs
def compile(config, program, loss_name=None):
def compile(config, program, loss_name=None, share_prog=None):
"""
Compile the program
......@@ -381,6 +381,7 @@ def compile(config, program, loss_name=None):
config(dict): config
program(): the program which is wrapped by
loss_name(str): loss name
share_prog(): the shared program, used for evaluation during training
Returns:
compiled_program(): a compiled program
......@@ -392,6 +393,7 @@ def compile(config, program, loss_name=None):
exec_strategy.num_iteration_per_drop_scope = 10
compiled_program = fluid.CompiledProgram(program).with_data_parallel(
share_vars_from=share_prog,
loss_name=loss_name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
......
......@@ -101,13 +101,14 @@ def main(args):
train_reader = Reader(config, 'train')()
train_dataloader.set_sample_list_generator(train_reader, places)
compiled_train_prog = program.compile(config, train_prog,
train_fetchs['loss'][0].name)
if config.validate:
valid_reader = Reader(config, 'valid')()
valid_dataloader.set_sample_list_generator(valid_reader, places)
compiled_valid_prog = program.compile(config, valid_prog)
compiled_train_prog = program.compile(config, train_prog,
train_fetchs['loss'][0].name)
compiled_valid_prog = program.compile(
config, valid_prog, share_prog=compiled_train_prog)
if args.vdl_dir:
from visualdl import LogWriter
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册