From 2946d3120fb7efc6374100cb70562d781da087bf Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Mon, 16 Apr 2018 15:47:21 +0800
Subject: [PATCH] Update

---
 fluid/neural_machine_translation/transformer/config.py | 2 +-
 fluid/neural_machine_translation/transformer/train.py  | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/fluid/neural_machine_translation/transformer/config.py b/fluid/neural_machine_translation/transformer/config.py
index 6e49a11b..bb8cbcf2 100644
--- a/fluid/neural_machine_translation/transformer/config.py
+++ b/fluid/neural_machine_translation/transformer/config.py
@@ -4,7 +4,7 @@ class TrainTaskConfig(object):
     pass_num = 20
 
     # the number of sequences contained in a mini-batch.
-    batch_size = 64
+    batch_size = 32
 
     # the hyper parameters for Adam optimizer.
     learning_rate = 0.001
diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py
index 68ac015a..d82a233a 100644
--- a/fluid/neural_machine_translation/transformer/train.py
+++ b/fluid/neural_machine_translation/transformer/train.py
@@ -196,24 +196,26 @@ def main():
         loss_name=avg_cost.name
         if TrainTaskConfig.use_avg_cost else sum_cost.name)
 
-    local_scopes = train_exe.executor.local_scopes()
     dev_count = fluid.core.get_cuda_device_count()
 
     for pos_enc_param_name in pos_enc_param_names:
         tensor = position_encoding_init(ModelHyperParams.max_length + 1,
                                         ModelHyperParams.d_model)
-        for place_id, local_scope in enumerate(local_scopes):
+        for place_id in xrange(dev_count):
+            local_scope = train_exe.executor.local_scope(place_id)
             local_scope.find_var(pos_enc_param_name).get_tensor().set(
                 tensor, fluid.CUDAPlace(place_id))
 
     train_data = read_multiple(reader=train_data, count=dev_count)
 
     for pass_id in xrange(TrainTaskConfig.pass_num):
         pass_start_time = time.time()
         for batch_id, data in enumerate(train_data()):
-            for place_id, data_buffer, local_scope in zip(range(len(data)), data, local_scopes):
+            for place_id, data_buffer in enumerate(data):
                 data_input_dict, util_input_dict = prepare_batch_input(
                     data_buffer, data_input_names, util_input_names,
                     ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
                     ModelHyperParams.n_head, ModelHyperParams.d_model)
+                local_scope = train_exe.executor.local_scope(place_id)
+
                 local_scope.find_var(lr_scheduler.learning_rate.name).get_tensor().set(
                     lr_scheduler.update_learning_rate(), fluid.CUDAPlace(place_id))
-- 
GitLab