Commit 44ba6942 authored by luotao1

put clone(for_test=True) before optimization phase

Parent: abf019f6
@@ -437,11 +437,8 @@ def split_data(data, num_part):
     ]
 
 
-def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
+def test_context(test_program, avg_cost, train_exe, dev_count, data_input_names,
                  sum_cost, token_num):
-    # Context to do validation.
-    test_program = train_progm.clone(for_test=True)
-
     val_data = DataReader(
         src_vocab_fpath=TrainTaskConfig.src_vocab_fpath,
         trg_vocab_fpath=TrainTaskConfig.trg_vocab_fpath,
@@ -503,7 +500,7 @@ def test_context(train_progm, avg_cost, train_exe, dev_count, data_input_names,
 
 
 def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
-               token_num, predict):
+               token_num, predict, test_program):
     # Initialize the parameters.
     if TrainTaskConfig.ckpt_path:
         lr_scheduler.current_steps = TrainTaskConfig.start_step
@@ -552,7 +549,7 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
         -1] + label_data_input_fields
 
     if TrainTaskConfig.val_file_pattern is not None:
-        test = test_context(train_progm, avg_cost, train_exe, dev_count,
+        test = test_context(test_program, avg_cost, train_exe, dev_count,
                             data_input_names, sum_cost, token_num)
 
     # the best cross-entropy value with label smoothing
@@ -1645,6 +1642,8 @@ def get_model(is_dist, is_async):
     local_lr_scheduler = LearningRateScheduler(ModelHyperParams.d_model,
                                                TrainTaskConfig.warmup_steps,
                                                TrainTaskConfig.learning_rate)
+    # Context to do validation.
+    test_program = fluid.default_main_program().clone(for_test=True)
 
     if not is_dist:
         optimizer = fluid.optimizer.Adam(
@@ -1669,7 +1668,7 @@ def get_model(is_dist, is_async):
             epsilon=TrainTaskConfig.eps)
         optimizer.minimize(sum_cost)
 
-    return sum_cost, avg_cost, predict, token_num, local_lr_scheduler
+    return sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program
 
 
 def update_args():
@@ -1703,7 +1702,7 @@ class DistTransformer2x2(TestDistRunnerBase):
     def run_trainer(self, use_cuda, args):
         place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
         TrainTaskConfig.use_gpu = use_cuda
-        sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model(
+        sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
             args.is_dist, not args.sync_mode)
 
         if args.is_dist:
@@ -1724,7 +1723,7 @@ class DistTransformer2x2(TestDistRunnerBase):
         TrainTaskConfig.local = not args.is_dist
         train_loop(startup_exe, trainer_prog, 1, sum_cost, avg_cost,
-                   local_lr_scheduler, token_num, predict)
+                   local_lr_scheduler, token_num, predict, test_program)
 
 
 if __name__ == "__main__":
...
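The ordering this commit enforces: Program.clone(for_test=True) snapshots the program at the point it is called, so it must run before optimizer.minimize() appends the backward and parameter-update operators; otherwise the "test" program would also execute gradient and update ops during validation. Below is a minimal sketch of the pattern with a hypothetical one-layer network standing in for the Transformer built by get_model() in the diff above:

    import paddle.fluid as fluid

    startup_prog = fluid.Program()
    train_prog = fluid.Program()
    with fluid.program_guard(train_prog, startup_prog):
        x = fluid.layers.data(name='x', shape=[13], dtype='float32')
        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
        pred = fluid.layers.fc(input=x, size=1)
        avg_cost = fluid.layers.mean(
            fluid.layers.square_error_cost(input=pred, label=y))

        # 1) Clone for evaluation BEFORE adding optimization ops, so the
        #    clone holds only the forward, inference-mode operators.
        test_program = train_prog.clone(for_test=True)

        # 2) Only now append backward and parameter-update ops; they land
        #    in train_prog but not in test_program.
        fluid.optimizer.Adam(learning_rate=1e-3).minimize(avg_cost)

Had the clone come after minimize(), test_program would also carry the backward and Adam ops, which is exactly the ordering hazard this commit removes by moving the clone out of test_context() and ahead of the optimizer phase in get_model().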