diff --git a/tools/eval.py b/tools/eval.py index ea0d8c309a9c657aecff95960a52748b4f5f34e3..db5ce4eec1d55ae4a4c45f2f4def4b7429b7d4ca 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -71,9 +71,8 @@ def main(args): valid_reader = Reader(config, 'valid')() valid_dataloader.set_sample_list_generator(valid_reader, place) - - #compiled_valid_prog = program.compile(config, valid_prog) - compiled_valid_prog = valid_prog + + compiled_valid_prog = program.compile(config, valid_prog) program.run(valid_dataloader, exe, compiled_valid_prog, valid_fetchs, 0, 'valid') diff --git a/tools/train.py b/tools/train.py index effd6bc5585a70378ba730ab4a1286bd05a344cc..9cbe3edd844ca192eec0d1fe305fba33a45530b2 100644 --- a/tools/train.py +++ b/tools/train.py @@ -24,7 +24,6 @@ from paddle.fluid.incubate.fleet.base import role_maker from paddle.fluid.incubate.fleet.collective import fleet from ppcls.data import Reader -from ppcls.utils import logger from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model, save_model import program @@ -86,8 +85,8 @@ def main(args): if config.validate: valid_reader = Reader(config, 'valid')() valid_dataloader.set_sample_list_generator(valid_reader, place) - #compiled_valid_prog = program.compile(config, valid_prog) - compiled_valid_prog = valid_prog + compiled_valid_prog = program.compile(config, valid_prog) + compiled_train_prog = fleet.main_program for epoch_id in range(config.epochs): # 1. train with train dataset