From 6aabfcb5f8af20abf7b871af8a726be252890a2b Mon Sep 17 00:00:00 2001
From: luotao1
Date: Thu, 20 Sep 2018 17:07:22 +0800
Subject: [PATCH] put clone(for_test=True) before optimization phase

---
 fluid/chinese_ner/train.py |  4 ++--
 .../transformer/train.py   | 15 +++++++--------
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/fluid/chinese_ner/train.py b/fluid/chinese_ner/train.py
index 5ce04b8c..7e59d2ed 100644
--- a/fluid/chinese_ner/train.py
+++ b/fluid/chinese_ner/train.py
@@ -270,6 +270,8 @@ def main(args):
     crf_decode = fluid.layers.crf_decoding(
         input=feature_out, param_attr=fluid.ParamAttr(name='crfw'))
 
+    inference_program = fluid.default_main_program().clone(for_test=True)
+
     sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
     sgd_optimizer.minimize(avg_cost)
 
@@ -282,8 +284,6 @@ def main(args):
 
     chunk_evaluator = fluid.metrics.ChunkEvaluator()
 
-    inference_program = fluid.default_main_program().clone(for_test=True)
-
     train_reader = paddle.batch(
         paddle.reader.shuffle(
             reader.file_reader(args.train_data_dir), buf_size=2000000),
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index aee780ed..d62c6bdd 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -257,11 +257,8 @@ def split_data(data, num_part):
     ]
 
 
-def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
+def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names,
                  sum_cost, token_num):
-    # Context to do validation.
-    test_program = train_progm.clone(for_test=True)
-
     val_data = reader.DataReader(
         src_vocab_fpath=args.src_vocab_fpath,
         trg_vocab_fpath=args.trg_vocab_fpath,
@@ -315,7 +312,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
 
 
 def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
-               token_num, predict):
+               token_num, predict, test_program):
     # Initialize the parameters.
     if TrainTaskConfig.ckpt_path:
         fluid.io.load_persistables(exe, TrainTaskConfig.ckpt_path)
@@ -360,7 +357,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
         -1] + label_data_input_fields
 
     if args.val_file_pattern is not None:
-        test = test_context(train_progm, avg_cost, train_exe, dev_count,
+        test = test_context(test_program, avg_cost, train_exe, dev_count,
                             data_input_names, sum_cost, token_num)
 
     # the best cross-entropy value with label smoothing
@@ -484,6 +481,8 @@ def train(args):
             TrainTaskConfig.warmup_steps, TrainTaskConfig.learning_rate)
 
+    test_program = fluid.default_main_program().clone(for_test=True)
+
     if args.local:
         optimizer = fluid.optimizer.Adam(
             learning_rate=lr_scheduler.learning_rate,
@@ -511,7 +510,7 @@ def train(args):
         print("local start_up:")
         train_loop(exe, fluid.default_main_program(), dev_count, sum_cost, avg_cost,
-                   lr_scheduler, token_num, predict)
+                   lr_scheduler, token_num, predict, test_program)
     else:
         port = os.getenv("PADDLE_PORT", "6174")
         pserver_ips = os.getenv("PADDLE_PSERVERS")  # ip,ip...
@@ -548,7 +547,7 @@ def train(args):
             with open('trainer_prog.desc', 'w') as f:
                 f.write(str(trainer_prog))
             train_loop(exe, trainer_prog, dev_count, sum_cost, avg_cost,
-                       lr_scheduler, token_num, predict)
+                       lr_scheduler, token_num, predict, test_program)
         else:
             print("environment var TRAINER_ROLE should be TRAINER os PSERVER")
-- 
GitLab
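
For reference: fluid's Program.clone(for_test=True) snapshots whatever the default main program contains at the moment it is called. Calling it after optimizer.minimize() therefore bakes the backward and optimizer ops into the cloned test/inference program, which is why this patch moves the clone ahead of the optimization phase. A minimal sketch of the resulting ordering follows; the toy network and names (x, y, pred, avg_cost) are illustrative only and are not taken from the patch:

    import paddle.fluid as fluid

    # Build the forward network only; no backward/optimizer ops exist yet.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')
    pred = fluid.layers.fc(input=x, size=1)
    avg_cost = fluid.layers.mean(
        fluid.layers.square_error_cost(input=pred, label=y))

    # Clone BEFORE minimize(): the cloned test program then holds only
    # the forward graph, and for_test=True switches ops such as dropout
    # and batch_norm to inference behavior.
    test_program = fluid.default_main_program().clone(for_test=True)

    # Only now append backward and optimization ops to the main program.
    sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-3)
    sgd_optimizer.minimize(avg_cost)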